'''
Copyright 2019 The Microsoft DeepSpeed Team
'''

import math

import torch
import torch.distributed as dist

try:
    from deepspeed.git_version_info import version
    from deepspeed.moe.utils import is_moe_param
    from deepspeed.ops.adam import DeepSpeedCPUAdam
    from deepspeed.ops.op_builder import UtilsBuilder
    from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
except ImportError:
    pass

from packaging import version as pkg_version
from torch._six import inf
from torch.distributed.distributed_c10d import _get_global_rank
from torch.optim import Optimizer

from colossalai.core import global_context as gpc
from colossalai.utils import report_memory_usage
from colossalai.utils.common import is_model_parallel_parameter
from .loss_scaler import LossScaler, DynamicLossScaler
from colossalai.context import ParallelMode

# Toggle this to true to enable correctness test
# with gradient partitioning and without
pg_correctness_test = False


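# Shadow the built-in input() with a no-op, presumably so that stray interactive
# prompts left in debugging code cannot stall a multi-process run.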
def input(msg):
    return


def split_half_float_double(tensors):
    dtypes = [
        "torch.cuda.HalfTensor",
        "torch.cuda.FloatTensor",
        "torch.cuda.DoubleTensor"
    ]
    buckets = []
    for i, dtype in enumerate(dtypes):
        bucket = [t for t in tensors if t.type() == dtype]
        if bucket:
            buckets.append(bucket)
    return buckets


def isclose(a, b, rtol=1e-09, atol=0.0):
    return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol)


def lcm(x, y):
    from math import gcd  # gcd lives in math on Python 3; fractions.gcd was removed in 3.9
    return x * y // gcd(x, y)


def get_alignment_padding(tensor_list, alignment):
    num_elements = sum([tensor.numel() for tensor in tensor_list])
    remainder = num_elements % alignment
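    # e.g. 7 elements with alignment 4 -> remainder 3 -> pad by 1; a zero remainder needs no padding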
    return (alignment - remainder) if remainder else remainder


def move_to_cpu(tensor_list):
    for tensor in tensor_list:
        tensor.data = tensor.data.cpu()


def print_rank_msg(msg):
    print(f"rank {dist.get_rank()} - {msg}")


class ZeroRedundancyOptimizer_Level_2(Optimizer):
    """
    ZeroRedundancyOptimizer_Level_2 is designed to reduce the memory footprint
    required for training large deep learning models.

    For more details please see ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
    https://arxiv.org/abs/1910.02054
    """
def __init__(self,
|
||
|
init_optimizer,
|
||
|
dp_parallel_mode=ParallelMode.DATA,
|
||
|
static_loss_scale=1.0,
|
||
|
dynamic_loss_scale=False,
|
||
|
dynamic_loss_args=None,
|
||
|
verbose=False,
|
||
|
contiguous_gradients=True,
|
||
|
reduce_bucket_size=500000000,
|
||
|
allgather_bucket_size=5000000000,
|
||
|
reduce_scatter=True,
|
||
|
overlap_comm=False,
|
||
|
cpu_offload=False,
|
||
|
clip_grad=0.0,
|
||
|
allreduce_always_fp32=False,
|
||
|
postscale_gradients=True,
|
||
|
gradient_predivide_factor=1.0,
|
||
|
gradient_accumulation_steps=1,
|
||
|
ignore_unused_parameters=True,
|
||
|
round_robin_gradients=False,
|
||
|
fp16_master_weights_and_gradients=False):
|
||
|
# mpu = None is removed from the parameter list
|
||
|
# tensor parallel will be automatically detected later
|
||
|
|
||
|
# LSG: default arguments for compatibility
|
||
|
has_moe_layers = False
|
||
|
partition_grads = True
|
||
|
expert_parallel_group = None
|
||
|
expert_data_parallel_group = None
|
||
|
self.timers = None
|
||
|
self.defaults = init_optimizer.defaults
|
||
|
|
||
|
dp_process_group = gpc.get_group(dp_parallel_mode)
|
||
|
if gpc.get_world_size(dp_parallel_mode) == 1:
|
||
|
partition_grads = False # for compatibility with dp size = 1
|
||
|
|
||
|
self.verbose = verbose
|
||
|
|
||
|
if dist.get_rank() == 0 and self.verbose:
|
||
|
print(f"Reduce bucket size {reduce_bucket_size}")
|
||
|
print(f"Allgather bucket size {allgather_bucket_size}")
|
||
|
print(f"CPU Offload: {cpu_offload}")
|
||
|
print(
|
||
|
f'Round robin gradient partitioning: {round_robin_gradients}')
|
||
|
        # The fused optimizer does all the work. We need this layer for two reasons:
|
||
|
# 1. maintain same user API from apex.fp16_utils
|
||
|
        # 2. keep common stuff here in case we need to add a new fused optimizer later
|
||
|
|
||
|
# differences from apex.fp16_utils:
|
||
|
# - assume all model params in fp16
|
||
|
        # - assume all params require grad
|
||
|
# - flat by groups, not keeping state. TODO: remove state explicitly?
|
||
|
        # - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
|
||
|
        if not torch.cuda.is_available():
|
||
|
raise SystemError("Cannot use fp16 without CUDA.")
|
||
|
self.optimizer = init_optimizer
|
||
|
|
||
|
# Load pre-built or JIT compile (un)flatten ops
|
||
|
util_ops = UtilsBuilder().load()
|
||
|
self.flatten = util_ops.flatten
|
||
|
self.unflatten = util_ops.unflatten
|
||
|
|
||
|
# ZeRO stage 1 (False) or 2 (True)
|
||
|
self.partition_gradients = partition_grads
|
||
|
|
||
|
self.reduce_scatter = reduce_scatter
|
||
|
|
||
|
self.overlap_comm = overlap_comm
|
||
|
|
||
|
self.cpu_offload = cpu_offload
|
||
|
|
||
|
self.deepspeed_adam_offload = cpu_offload
|
||
|
|
||
|
self.device = torch.cuda.current_device() if not self.cpu_offload else 'cpu'
|
||
|
|
||
|
self.dp_process_group = dp_process_group
|
||
|
|
||
|
# expert parallel group
|
||
|
self.ep_process_group = expert_parallel_group
|
||
|
|
||
|
# data parallel group for experts
|
||
|
self.expert_dp_process_group = expert_data_parallel_group
|
||
|
|
||
|
# data parallel size for non-experts
|
||
|
dp_size = dist.get_world_size(group=self.dp_process_group)
|
||
|
|
||
|
        # For MoE models this may be different for different param groups.
|
||
|
# It will be modified during MoE setup later in the init
|
||
|
self.real_dp_process_group = [
|
||
|
dp_process_group for i in range(len(self.optimizer.param_groups))
|
||
|
]
|
||
|
self.partition_count = [dp_size for i in range(
|
||
|
len(self.optimizer.param_groups))]
|
||
|
|
||
|
self.is_gradient_accumulation_boundary = True
|
||
|
|
||
|
# CPU-Offload requires contiguous gradients
|
||
|
self.contiguous_gradients = contiguous_gradients or cpu_offload
|
||
|
|
||
|
self.has_moe_layers = has_moe_layers
|
||
|
|
||
|
if self.has_moe_layers:
|
||
|
self._configure_moe_settings()
|
||
|
|
||
|
if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_world_size(ParallelMode.TENSOR) == 1:
|
||
|
self.model_parallel_group = None
|
||
|
self.model_parallel_rank = 0
|
||
|
else:
|
||
|
self.model_parallel_group = gpc.get_group(ParallelMode.TENSOR)
|
||
|
self.model_parallel_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
||
|
|
||
|
self.overflow = False
|
||
|
self.clip_grad = clip_grad
|
||
|
self.allreduce_always_fp32 = allreduce_always_fp32
|
||
|
self.gradient_predivide_factor = gradient_predivide_factor
|
||
|
self.postscale_gradients = postscale_gradients
|
||
|
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||
|
self.micro_step_id = 0
|
||
|
self.ignore_unused_parameters = ignore_unused_parameters
|
||
|
self.round_robin_gradients = round_robin_gradients
|
||
|
|
||
|
self.extra_large_param_to_reduce = None
|
||
|
self.fp16_master_weights_and_gradients = fp16_master_weights_and_gradients
|
||
|
|
||
|
if self.fp16_master_weights_and_gradients:
|
||
|
assert self.cpu_offload and type(self.optimizer) in [
|
||
|
                DeepSpeedCPUAdam], f"fp16_master_weights_and_gradients requires an optimizer that supports keeping fp16 master weights and gradients while keeping the optimizer states in fp32. Currently this is only supported with ZeRO-Offload and DeepSpeedCPUAdam. But the current setting is ZeRO-Offload:{self.cpu_offload} and optimizer type {type(self.optimizer)}. Either disable fp16_master_weights_and_gradients or enable ZeRO-2 Offload with DeepSpeedCPUAdam"
|
||
|
|
||
|
if self.reduce_scatter:
|
||
|
assert not self.allreduce_always_fp32, "allreduce_always_fp32 is not yet supported with ZeRO-2 with reduce scatter enabled"
|
||
|
assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-2 with reduce scatter enabled"
|
||
|
assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-2 with reduce scatter enabled"
|
||
|
|
||
|
# param flattened by groups
|
||
|
self.fp16_groups = []
|
||
|
self.fp16_groups_flat = []
|
||
|
|
||
|
# param partitioned by data parallel degree
|
||
|
# this will contain a list of equal sized tensors
|
||
|
# each of which will be updated by a different process
|
||
|
self.parallel_partitioned_fp16_groups = []
|
||
|
|
||
|
# a single 32-bit partition of the parallel partitioned parameters
|
||
|
# that this process will update
|
||
|
self.single_partition_of_fp32_groups = []
|
||
|
|
||
|
# param partition info
|
||
|
|
||
|
# These are the parameters in each group that will not be updated by this process directly
|
||
|
self.params_not_in_partition = []
|
||
|
|
||
|
# These are the parameters that will be updated by this process directly
|
||
|
self.params_in_partition = []
|
||
|
|
||
|
        # Offset from the first parameter in self.params_in_partition.
|
||
|
# the parameter boundaries may not align with partition boundaries
|
||
|
# so we need to keep track of the offset
|
||
|
self.first_offset = []
|
||
|
|
||
|
# number of elements per partition in each group
|
||
|
self.partition_size = []
|
||
|
|
||
|
        # align nccl all-gather send buffers to 4-byte boundary
|
||
|
# 4-byte alignment/sizeof(fp16) = 2
|
||
|
self.nccl_start_alignment_factor = 2
|
||
|
|
||
|
assert (
|
||
3 years ago
|
allgather_bucket_size % self.nccl_start_alignment_factor == 0), f"allgather_bucket_size must be a multiple of nccl_start_alignment_factor, {self.nccl_start_alignment_factor} "
|
||
3 years ago
|
|
||
|
self.all_reduce_print = False
|
||
|
self.dtype = self.optimizer.param_groups[0]['params'][0].dtype
|
||
|
|
||
|
self.round_robin_fp16_groups = []
|
||
|
        self.round_robin_fp16_indices = []
|
||
|
|
||
|
# padding on each partition for alignment purposes
|
||
|
self.groups_padding = []
|
||
|
# loop to deal with groups
|
||
|
for i, param_group in enumerate(self.optimizer.param_groups):
|
||
|
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
|
||
|
|
||
|
            # push this group to the list before modifying it
|
||
|
# TODO: Explore simplification that avoids the extra book-keeping by pushing the reordered group
|
||
|
self.fp16_groups.append(param_group['params'])
|
||
|
|
||
|
# Record padding required to align group to world size
|
||
|
if partition_id == dist.get_world_size(
|
||
|
group=self.real_dp_process_group[i]) - 1:
|
||
|
padding = get_alignment_padding(self.fp16_groups[i],
|
||
|
self.partition_count[i])
|
||
|
else:
|
||
|
padding = 0
|
||
|
self.groups_padding.append(padding)
|
||
|
|
||
|
# not sure why apex was cloning the weights before flattening
|
||
|
# removing cloning here
|
||
|
|
||
|
if self.verbose:
|
||
|
report_memory_usage(f"Before moving param group {i} to CPU")
|
||
|
# move all the parameters to cpu to free up GPU space for creating flat buffer
|
||
|
move_to_cpu(self.fp16_groups[i])
|
||
|
if self.verbose:
|
||
|
report_memory_usage(f"After moving param group {i} to CPU")
|
||
|
|
||
|
# Reorder group parameters for load balancing of gradient partitioning during backward among ranks.
|
||
|
# This ensures that gradients are reduced in a fashion such that ownership round robins among the ranks.
|
||
|
# For example, rather than 3 gradients (g_n+2, g_n+1, g_n) that are reduced consecutively belonging
|
||
|
# to the same rank, instead they will belong to 3 ranks (r_m+2, r_m+1, r_m).
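            # e.g. with 5 params [p0..p4] and 2 ranks, the reordered list is
            # [p0, p2, p4, p1, p3]; gradients arriving in backward order
            # (p4, p3, p2, ...) then alternate between the two owning ranks.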
|
||
|
if self.round_robin_gradients:
|
||
|
round_robin_tensors, round_robin_indices = self._round_robin_reorder(
|
||
|
self.fp16_groups[i],
|
||
|
dist.get_world_size(group=self.real_dp_process_group[i])
|
||
|
)
|
||
|
else:
|
||
|
round_robin_tensors = self.fp16_groups[i]
|
||
|
round_robin_indices = list(range(len(self.fp16_groups[i])))
|
||
|
|
||
|
self.round_robin_fp16_groups.append(round_robin_tensors)
|
||
|
            self.round_robin_fp16_indices.append(round_robin_indices)
|
||
|
|
||
|
# create flat buffer in CPU and move to GPU
|
||
|
self.fp16_groups_flat.append(
|
||
|
self.flatten_dense_tensors_aligned(
|
||
|
self.round_robin_fp16_groups[i],
|
||
|
self.nccl_start_alignment_factor *
|
||
|
dist.get_world_size(group=self.real_dp_process_group[i])).cuda(
|
||
|
torch.cuda.current_device()))
|
||
|
|
||
|
if self.verbose:
|
||
|
report_memory_usage(
|
||
|
f"After flattening and moving param group {i} to GPU")
|
||
|
|
||
|
if dist.get_rank(group=self.real_dp_process_group[i]) == 0:
|
||
|
report_memory_usage(
|
||
|
f"After Flattening and after emptying param group {i} cache")
|
||
|
|
||
|
# set model fp16 weight to slices of flattened buffer
|
||
|
self._update_model_fp16_weights(i)
|
||
|
|
||
|
            # divide the flat weights into near-equal partitions, one per data parallel rank
|
||
|
# each process will compute on a different part of the partition
|
||
|
data_parallel_partitions = self.get_data_parallel_partitions(
|
||
|
self.fp16_groups_flat[i],
|
||
|
i)
|
||
|
self.parallel_partitioned_fp16_groups.append(
|
||
|
data_parallel_partitions)
|
||
|
|
||
|
# verify that data partition start locations are 4-byte aligned
|
||
|
for partitioned_data in data_parallel_partitions:
|
||
|
assert (partitioned_data.data_ptr() %
|
||
|
(2 * self.nccl_start_alignment_factor) == 0)
|
||
|
|
||
|
# a partition of the fp32 master weights that will be updated by this process
|
||
|
if not fp16_master_weights_and_gradients:
|
||
|
self.single_partition_of_fp32_groups.append(
|
||
|
self.parallel_partitioned_fp16_groups[i][partition_id].to(
|
||
|
self.device).clone().float().detach())
|
||
|
else:
|
||
|
self.single_partition_of_fp32_groups.append(
|
||
|
self.parallel_partitioned_fp16_groups[i][partition_id].to(
|
||
|
self.device).clone().half().detach())
|
||
|
|
||
|
            # modify the optimizer to have the flat master weight
|
||
|
self.single_partition_of_fp32_groups[
|
||
|
i].requires_grad = True # keep this in case internal optimizer uses it
|
||
|
param_group['params'] = [self.single_partition_of_fp32_groups[i]]
|
||
|
|
||
|
partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(
|
||
|
group=self.real_dp_process_group[i])
|
||
|
params_in_partition, params_not_in_partition, first_offset = self.get_partition_info(
|
||
|
self.round_robin_fp16_groups[i],
|
||
|
partition_size,
|
||
|
partition_id)
|
||
|
|
||
|
self.partition_size.append(partition_size)
|
||
|
self.params_in_partition.append(params_in_partition)
|
||
|
self.params_not_in_partition.append(params_not_in_partition)
|
||
|
self.first_offset.append(first_offset)
|
||
|
|
||
|
for rank in range(dist.get_world_size()):
|
||
|
if dist.get_rank() == rank and self.verbose:
|
||
|
print(
|
||
|
f"Rank: {rank} partition count {self.partition_count} and sizes{[(p.numel(), self.is_moe_param_group[i] if hasattr(self, 'is_moe_param_group') else False) for i, p in enumerate(self.single_partition_of_fp32_groups)]} "
|
||
|
)
|
||
|
dist.barrier()
|
||
|
# exit(0)
|
||
|
self.reduce_bucket_size = int(reduce_bucket_size)
|
||
|
self.allgather_bucket_size = int(allgather_bucket_size)
|
||
|
|
||
|
self.reduction_event = torch.cuda.Event(
|
||
|
enable_timing=False, blocking=False)
|
||
|
self.reduction_stream = torch.cuda.Stream()
|
||
|
self.cpu_computation_stream = torch.cuda.Stream()
|
||
|
self.copy_grad_stream = torch.cuda.Stream()
|
||
|
self.callback_queued = False
|
||
|
|
||
|
self.param_dict = {}
|
||
|
|
||
|
# map between param_id and bool to specify if a param is in this partition
|
||
|
self.is_param_in_current_partition = {}
|
||
|
|
||
|
self.grads_in_ipg_bucket = []
|
||
|
self.params_in_ipg_bucket = []
|
||
|
self.elements_in_ipg_bucket = 0
|
||
|
self.params_already_reduced = []
|
||
|
self._release_ipg_buffers()
|
||
|
self.previous_reduced_grads = None
|
||
|
self.ipg_bucket_has_moe_params = False
|
||
|
|
||
|
# simplified param id
|
||
|
self.param_id = {}
|
||
|
|
||
|
largest_param_numel = 0
|
||
|
count = 0
|
||
|
for i, params_group in enumerate(self.fp16_groups):
|
||
|
for param in params_group:
|
||
|
unique_id = id(param)
|
||
|
self.param_id[unique_id] = count
|
||
|
self.param_dict[count] = param
|
||
|
self.params_already_reduced.append(False)
|
||
|
if param.numel() > largest_param_numel:
|
||
|
largest_param_numel = param.numel()
|
||
|
count = count + 1
|
||
|
|
||
|
for param_group in self.params_in_partition:
|
||
|
for param in param_group:
|
||
|
self.is_param_in_current_partition[self.get_param_id(
|
||
|
param)] = True
|
||
|
|
||
|
for param_group in self.params_not_in_partition:
|
||
|
for param in param_group:
|
||
|
self.is_param_in_current_partition[self.get_param_id(
|
||
|
param)] = False
|
||
|
|
||
|
if self.cpu_offload:
|
||
|
self.accumulated_grads_in_cpu = {}
|
||
|
self.norm_for_param_grads = {}
|
||
|
self.local_overflow = False
|
||
|
self.grad_position = {}
|
||
|
self.temp_grad_buffer_for_cpu_offload = torch.zeros(
|
||
|
largest_param_numel,
|
||
|
device=self.device,
|
||
|
dtype=self.dtype).pin_memory()
|
||
|
self.temp_grad_buffer_for_gpu_offload = torch.zeros(
|
||
|
largest_param_numel,
|
||
|
device=torch.cuda.current_device(),
|
||
|
dtype=self.dtype)
|
||
|
|
||
|
for i, params_group in enumerate(self.fp16_groups):
|
||
|
self.get_grad_position(i,
|
||
|
self.params_in_partition[i],
|
||
|
self.first_offset[i],
|
||
|
self.partition_size[i])
|
||
|
|
||
|
# mapping from parameter to partition that it belongs to
|
||
|
self.param_to_partition_ids = {}
|
||
|
|
||
|
# stores if a partition has been reduced in this step
|
||
|
self.is_partition_reduced = {}
|
||
|
|
||
|
# number of grads in partition that still need to be computed
|
||
|
self.remaining_grads_in_partition = {}
|
||
|
|
||
|
# total number of grads in partition
|
||
|
self.total_grads_in_partition = {}
|
||
|
|
||
|
# stores if a grad in a partition has been computed or not
|
||
|
self.is_grad_computed = {}
|
||
|
|
||
|
# stores the offset at which a parameter gradient needs to be inserted in a partition
|
||
|
self.grad_partition_insertion_offset = {}
|
||
|
|
||
|
# the offset in the gradient at which it must be inserted at the beginning of the partition
|
||
|
self.grad_start_offset = {}
|
||
|
|
||
|
# will store the averaged gradients required by this partition
|
||
|
self.averaged_gradients = {}
|
||
|
|
||
|
# store index of first parameter in each partition
|
||
|
self.first_param_index_in_partition = {}
|
||
|
|
||
|
# initializes all data structures for implementing gradient partitioning
|
||
|
self.initialize_gradient_partitioning_data_structures()
|
||
|
|
||
|
# resets the data structure value for the next backward propagation
|
||
|
self.reset_partition_gradient_structures()
|
||
|
|
||
|
# creates backward hooks for gradient partitioning
|
||
|
if self.partition_gradients or self.overlap_comm:
|
||
|
self.create_reduce_and_remove_grad_hooks()
|
||
|
|
||
|
# we may have a way of fusing dynamic scale. Do not support for now
|
||
|
if self.dtype == torch.float or not dynamic_loss_scale:
|
||
|
loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale
|
||
|
|
||
|
self.dynamic_loss_scale = False
|
||
|
self.loss_scaler = LossScaler(scale=loss_scale_value)
|
||
|
cur_iter = 0
|
||
|
else:
|
||
|
if dynamic_loss_args is None:
|
||
|
self.loss_scaler = DynamicLossScaler()
|
||
|
else:
|
||
|
self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
|
||
|
|
||
|
self.dynamic_loss_scale = True
|
||
|
|
||
|
if self.verbose:
|
||
|
report_memory_usage("Before initializing optimizer states")
|
||
|
self.initialize_optimizer_states()
|
||
|
if self.verbose:
|
||
|
report_memory_usage("After initializing optimizer states")
|
||
|
|
||
|
if dist.get_rank() == 0:
|
||
|
print(f"optimizer state initialized")
|
||
|
|
||
|
if dist.get_rank(group=self.dp_process_group) == 0:
|
||
|
report_memory_usage(f"After initializing ZeRO optimizer")
|
||
|
|
||
|
def _configure_moe_settings(self):
|
||
|
assert self.contiguous_gradients, "Contiguous Gradients in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
|
||
|
assert self.reduce_scatter, "Reduce Scatter in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
|
||
|
|
||
|
def is_moe_group(group):
|
||
|
return 'moe' in group and group['moe']
|
||
|
|
||
|
assert any([is_moe_group(group) for group in
|
||
|
self.optimizer.param_groups]), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer"
|
||
|
self.is_moe_param_group = []
|
||
|
for i, group in enumerate(self.optimizer.param_groups):
|
||
|
if is_moe_group(group):
|
||
|
assert all(
|
||
|
[is_moe_param(param) for param in group['params']]), "All params in MoE group must be MoE params"
|
||
|
self.real_dp_process_group[i] = self.expert_dp_process_group
|
||
|
self.partition_count[i] = dist.get_world_size(
|
||
|
group=self.expert_dp_process_group)
|
||
|
self.is_moe_param_group.append(True)
|
||
|
else:
|
||
|
self.is_moe_param_group.append(False)
|
||
|
|
||
|
assert self.expert_dp_process_group is not None, "Expert data parallel group should be configured with MoE"
|
||
|
assert self.ep_process_group is not None, "Expert parallel group should be configured with MoE"
|
||
|
|
||
|
def _update_model_fp16_weights(self, group_index):
|
||
|
updated_params = self.unflatten(self.fp16_groups_flat[group_index],
|
||
|
self.round_robin_fp16_groups[group_index])
|
||
|
for p, q in zip(self.round_robin_fp16_groups[group_index], updated_params):
|
||
|
p.data = q.data
|
||
|
|
||
|
# set model fp16 weight to slices of reordered flattened buffer
|
||
|
for param_index, param in enumerate(self.fp16_groups[group_index]):
|
||
|
            new_index = self.round_robin_fp16_indices[group_index][param_index]
|
||
|
param.data = self.round_robin_fp16_groups[group_index][new_index].data
|
||
|
|
||
|
def _round_robin_reorder(self, tensor_list, num_partitions):
|
||
|
|
||
|
# disable round robin if need to debug something
|
||
|
# return tensor_list, list(range(len(tensor_list)))
|
||
|
|
||
|
partition_tensors = {}
|
||
|
|
||
|
for i, tensor in enumerate(tensor_list):
|
||
|
j = i % num_partitions
|
||
|
if not j in partition_tensors:
|
||
|
partition_tensors[j] = []
|
||
|
partition_tensors[j].append((i, tensor))
|
||
|
|
||
|
reordered_tensors = []
|
||
|
reordered_indices = {}
|
||
|
|
||
|
for partition_index in partition_tensors.keys():
|
||
|
for i, (original_index, tensor) in enumerate(partition_tensors[partition_index]):
|
||
|
reordered_indices[original_index] = len(reordered_tensors)
|
||
|
reordered_tensors.append(tensor)
|
||
|
|
||
|
return reordered_tensors, reordered_indices
|
||
|
|
||
|
def _release_ipg_buffers(self):
|
||
|
if self.contiguous_gradients:
|
||
|
self.ipg_buffer = None
|
||
|
self.grads_in_partition = None
|
||
|
self.grads_in_partition_offset = 0
|
||
|
|
||
|
def initialize_optimizer_states(self):
|
||
|
|
||
|
for i, group in enumerate(self.fp16_groups):
|
||
|
single_grad_partition = torch.zeros(
|
||
|
int(self.partition_size[i]),
|
||
|
dtype=self.single_partition_of_fp32_groups[i].dtype,
|
||
|
device=self.device)
|
||
|
self.single_partition_of_fp32_groups[
|
||
|
i].grad = single_grad_partition.pin_memory(
|
||
|
) if self.cpu_offload else single_grad_partition
|
||
|
|
||
|
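        # Run a dummy step on the zero-filled gradient partitions so that the wrapped
        # optimizer allocates its internal state (e.g. Adam moments) up front.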
self.optimizer.step()
|
||
|
|
||
|
if not self.cpu_offload:
|
||
|
for group in self.single_partition_of_fp32_groups:
|
||
|
group.grad = None # class init
|
||
|
|
||
|
return
|
||
|
|
||
|
#########################################################################
|
||
|
#################### ZeRO Stage 1 - reduce gradients ####################
|
||
|
#########################################################################
|
||
|
|
||
|
def reduce_gradients(self, pipeline_parallel=False):
|
||
|
world_size = dist.get_world_size(self.dp_process_group)
|
||
|
my_rank = dist.get_rank(self.dp_process_group)
|
||
|
|
||
|
# with PP we must create ipg buffer, since backward is handled outside zero
|
||
|
if pipeline_parallel and self.contiguous_gradients:
|
||
|
self.ipg_buffer = []
|
||
|
buf_0 = torch.empty(int(self.reduce_bucket_size),
|
||
|
dtype=self.dtype,
|
||
|
device=torch.cuda.current_device())
|
||
|
self.ipg_buffer.append(buf_0)
|
||
|
self.ipg_index = 0
|
||
|
|
||
|
if not self.overlap_comm:
|
||
|
for i, group in enumerate(self.fp16_groups):
|
||
|
for param in group:
|
||
|
if param.grad is not None:
|
||
|
self.reduce_ready_partitions_and_remove_grads(param, i)
|
||
|
|
||
|
# reduce any pending grads in either hook/non-hook case
|
||
|
self.overlapping_partition_gradients_reduce_epilogue()
|
||
|
|
||
|
#########################################################################
|
||
|
#########################ZeRO Partition Gradients########################
|
||
|
#########################################################################
|
||
|
|
||
|
def get_first_param_index(self, group_id, param_group, partition_id):
|
||
|
for index, param in enumerate(param_group):
|
||
|
param_id = self.get_param_id(param)
|
||
|
if partition_id in self.param_to_partition_ids[group_id][param_id]:
|
||
|
return index
|
||
|
return None
|
||
|
|
||
|
def initialize_gradient_partitioning_data_structures(self):
|
||
|
|
||
|
for i, param_group in enumerate(self.round_robin_fp16_groups):
|
||
|
|
||
|
total_partitions = dist.get_world_size(
|
||
|
group=self.real_dp_process_group[i])
|
||
|
|
||
|
self.param_to_partition_ids[i] = {}
|
||
|
self.is_partition_reduced[i] = {}
|
||
|
self.total_grads_in_partition[i] = {}
|
||
|
self.remaining_grads_in_partition[i] = {}
|
||
|
self.is_grad_computed[i] = {}
|
||
|
self.grad_partition_insertion_offset[i] = {}
|
||
|
self.grad_start_offset[i] = {}
|
||
|
self.first_param_index_in_partition[i] = {}
|
||
|
|
||
|
for partition_id in range(total_partitions):
|
||
|
self.is_grad_computed[i][partition_id] = {}
|
||
|
self.grad_partition_insertion_offset[i][partition_id] = {}
|
||
|
self.grad_start_offset[i][partition_id] = {}
|
||
|
self.total_grads_in_partition[i][partition_id] = 0
|
||
|
self.initialize_gradient_partition(
|
||
|
i, param_group, partition_id)
|
||
|
self.is_partition_reduced[i][partition_id] = False
|
||
|
self.first_param_index_in_partition[i][
|
||
|
partition_id] = self.get_first_param_index(
|
||
|
i,
|
||
|
param_group,
|
||
|
partition_id)
|
||
|
|
||
|
def independent_gradient_partition_epilogue(self):
|
||
|
if self.verbose:
|
||
|
self.report_ipg_memory_usage(
|
||
|
f"In ipg_epilogue before reduce_ipg_grads", 0)
|
||
|
self.reduce_ipg_grads()
|
||
|
if self.verbose:
|
||
|
self.report_ipg_memory_usage(
|
||
|
f"In ipg_epilogue after reduce_ipg_grads", 0)
|
||
|
|
||
|
# if dist.get_rank() == 0:
|
||
|
# print()("Params already reduced %s", self.params_already_reduced)
|
||
|
for i in range(len(self.params_already_reduced)):
|
||
|
self.params_already_reduced[i] = False
|
||
|
|
||
|
if self.overlap_comm:
|
||
|
torch.cuda.synchronize()
|
||
|
# It is safe to clear previously reduced grads of other partitions
|
||
|
self._clear_previous_reduced_grads()
|
||
|
|
||
|
if self.cpu_offload is False:
|
||
|
for i, _ in enumerate(self.fp16_groups):
|
||
|
|
||
|
if not i in self.averaged_gradients or self.averaged_gradients[i] is None:
|
||
|
self.averaged_gradients[i] = self.get_flat_partition(
|
||
|
self.params_in_partition[i],
|
||
|
self.first_offset[i],
|
||
|
self.partition_size[i],
|
||
|
dtype=self.dtype,
|
||
|
device=torch.cuda.current_device(),
|
||
|
return_tensor_list=True)
|
||
|
else:
|
||
|
avg_new = self.get_flat_partition(self.params_in_partition[i],
|
||
|
self.first_offset[i],
|
||
|
self.partition_size[i],
|
||
|
dtype=self.dtype,
|
||
|
device=torch.cuda.current_device(),
|
||
|
return_tensor_list=True)
|
||
|
|
||
|
for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i], avg_new):
|
||
|
accumulated_grad.add_(new_avg_grad)
|
||
|
|
||
|
self._release_ipg_buffers()
|
||
|
|
||
|
# No need to keep the gradients anymore.
|
||
|
# All gradients required by the step
|
||
|
# are in self.averaged_gradients
|
||
|
self.zero_grad()
|
||
|
|
||
|
if self.verbose:
|
||
|
report_memory_usage(f"End ipg_epilogue")
|
||
|
|
||
|
# resets all partition to no reduced
|
||
|
# sets remaining grads to the total number of grads in each partition
|
||
|
# set is grad computed to false for all grads in partition
|
||
|
def reset_partition_gradient_structures(self):
|
||
|
for i, _ in enumerate(self.fp16_groups):
|
||
|
total_partitions = dist.get_world_size(
|
||
|
group=self.real_dp_process_group[i])
|
||
|
for partition_id in range(total_partitions):
|
||
|
self.is_partition_reduced[i][partition_id] = False
|
||
|
self.remaining_grads_in_partition[i][
|
||
|
partition_id] = self.total_grads_in_partition[i][partition_id]
|
||
|
|
||
|
for param_id in self.is_grad_computed[i][partition_id]:
|
||
|
self.is_grad_computed[i][partition_id][param_id] = False
|
||
|
|
||
|
def initialize_gradient_partition(self, i, param_group, partition_id):
|
||
|
def set_key_value_list(dictionary, key, value):
|
||
|
if key in dictionary:
|
||
|
dictionary[key].append(value)
|
||
|
else:
|
||
|
dictionary[key] = [value]
|
||
|
|
||
|
def increment_value(dictionary, key):
|
||
|
if key in dictionary:
|
||
|
dictionary[key] += 1
|
||
|
else:
|
||
|
dictionary[key] = 1
|
||
|
|
||
|
partition_size = self.partition_size[i]
|
||
|
|
||
|
start_index = partition_size * partition_id
|
||
|
end_index = partition_size * (partition_id + 1)
|
||
|
|
||
|
current_index = 0
|
||
|
first_offset = 0
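        # Walk the params in flat order; a param contributes a gradient slice to this
        # partition if its range [current_index, current_index + numel) overlaps
        # [start_index, end_index). first_offset records where the partition starts
        # inside the first param that straddles the partition boundary.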
|
||
|
|
||
|
for param in param_group:
|
||
|
|
||
|
param_size = param.numel()
|
||
|
param_id = self.get_param_id(param)
|
||
|
|
||
|
if (current_index >= start_index and current_index < end_index):
|
||
|
set_key_value_list(self.param_to_partition_ids[i],
|
||
|
param_id,
|
||
|
partition_id)
|
||
|
increment_value(self.total_grads_in_partition[i], partition_id)
|
||
|
|
||
|
self.is_grad_computed[i][partition_id][param_id] = False
|
||
|
|
||
|
self.grad_partition_insertion_offset[i][partition_id][
|
||
|
param_id] = current_index - start_index
|
||
|
self.grad_start_offset[i][partition_id][param_id] = 0
|
||
|
|
||
|
elif start_index > current_index and start_index < (current_index +
|
||
|
param_size):
|
||
|
assert (
|
||
3 years ago
|
first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition"
|
||
3 years ago
|
first_offset = start_index - current_index
|
||
|
|
||
|
set_key_value_list(self.param_to_partition_ids[i],
|
||
|
param_id,
|
||
|
partition_id)
|
||
|
increment_value(self.total_grads_in_partition[i], partition_id)
|
||
|
|
||
|
self.is_grad_computed[i][partition_id][param_id] = False
|
||
|
|
||
|
self.grad_partition_insertion_offset[i][partition_id][param_id] = 0
|
||
|
self.grad_start_offset[i][partition_id][param_id] = first_offset
|
||
|
|
||
|
current_index = current_index + param_size
|
||
|
|
||
|
def overlapping_partition_gradients_reduce_epilogue(self):
|
||
|
self.independent_gradient_partition_epilogue()
|
||
|
|
||
|
def create_reduce_and_remove_grad_hooks(self):
|
||
|
self.grad_accs = []
|
||
|
for i, param_group in enumerate(self.fp16_groups):
|
||
|
for param in param_group:
|
||
|
if param.requires_grad:
|
||
|
def wrapper(param, i):
|
||
|
param_tmp = param.expand_as(param)
|
||
|
grad_acc = param_tmp.grad_fn.next_functions[0][0]
|
||
|
|
||
|
def reduce_partition_and_remove_grads(*notneeded):
|
||
|
self.reduce_ready_partitions_and_remove_grads(
|
||
|
param, i)
|
||
|
|
||
|
grad_acc.register_hook(
|
||
|
reduce_partition_and_remove_grads)
|
||
|
self.grad_accs.append(grad_acc)
|
||
|
|
||
|
wrapper(param, i)
|
||
|
|
||
|
def get_param_id(self, param):
|
||
|
unique_id = id(param)
|
||
|
return self.param_id[unique_id]
|
||
|
|
||
|
def report_ipg_memory_usage(self, tag, param_elems):
|
||
|
elem_count = self.elements_in_ipg_bucket + param_elems
|
||
|
percent_of_bucket_size = (
|
||
3 years ago
|
100.0 * elem_count) // self.reduce_bucket_size
|
||
3 years ago
|
if self.verbose:
|
||
|
report_memory_usage(
|
||
|
f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}"
|
||
|
)
|
||
|
|
||
|
# create a flat tensor aligned at the alignment boundary
|
||
|
def flatten_dense_tensors_aligned(self, tensor_list, alignment):
|
||
|
num_elements = 0
|
||
|
for tensor in tensor_list:
|
||
|
num_elements = num_elements + tensor.numel()
|
||
|
|
||
|
remaining = num_elements % alignment
|
||
|
|
||
|
if remaining:
|
||
|
elements_to_add = alignment - remaining
|
||
|
pad_tensor = torch.zeros(elements_to_add,
|
||
|
device=tensor_list[0].device,
|
||
|
dtype=tensor_list[0].dtype)
|
||
|
padded_tensor_list = tensor_list + [pad_tensor]
|
||
|
|
||
|
num_elements = num_elements + elements_to_add
|
||
|
else:
|
||
|
padded_tensor_list = tensor_list
|
||
|
|
||
|
return self.flatten(padded_tensor_list)
|
||
|
|
||
|
############### Independent Partition Gradient ########################
|
||
|
def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
|
||
|
if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size:
|
||
|
self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads",
|
||
|
param.numel())
|
||
|
self.reduce_ipg_grads()
|
||
|
if self.contiguous_gradients and self.overlap_comm:
|
||
|
# Swap ipg_index between 0 and 1
|
||
|
self.ipg_index = 1 - self.ipg_index
|
||
|
|
||
|
self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads",
|
||
|
param.numel())
|
||
|
|
||
|
param_id = self.get_param_id(param)
|
||
|
assert self.params_already_reduced[param_id] == False, \
|
||
|
f"The parameter {param_id} has already been reduced. \
|
||
|
Gradient computed twice for this partition. \
|
||
|
Multiple gradient reduction is currently not supported"
|
||
|
|
||
|
if param.numel() > self.reduce_bucket_size:
|
||
|
self.extra_large_param_to_reduce = param
|
||
|
|
||
|
elif self.contiguous_gradients:
|
||
|
# keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening
|
||
|
new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(
|
||
|
0,
|
||
|
self.elements_in_ipg_bucket,
|
||
|
param.numel())
|
||
|
new_grad_tensor.copy_(param.grad.view(-1))
|
||
|
param.grad.data = new_grad_tensor.data.view_as(param.grad)
|
||
|
|
||
|
self.elements_in_ipg_bucket += param.numel()
|
||
|
|
||
|
assert param.grad is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"
|
||
|
|
||
|
self.grads_in_ipg_bucket.append(param.grad)
|
||
|
self.params_in_ipg_bucket.append((i, param, param_id))
|
||
|
|
||
|
# make sure the average tensor function knows how to average the gradients
|
||
|
if is_moe_param(param):
|
||
|
self.ipg_bucket_has_moe_params = True
|
||
|
|
||
|
self.report_ipg_memory_usage("End ipg_remove_grads", 0)
|
||
|
|
||
|
def print_rank_0(self, message):
|
||
|
if dist.get_rank() == 0 and self.verbose:
|
||
|
print(message)
|
||
|
|
||
|
def gradient_reduction_w_predivide(self, tensor):
|
||
|
|
||
|
dp_world_size = dist.get_world_size(group=self.dp_process_group)
|
||
|
|
||
|
tensor_to_allreduce = tensor
|
||
|
|
||
|
if self.allreduce_always_fp32:
|
||
|
tensor_to_allreduce = tensor.float()
|
||
|
|
||
|
if self.postscale_gradients:
|
||
|
if self.gradient_predivide_factor != 1.0:
|
||
|
tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor)
|
||
|
|
||
|
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
|
||
|
|
||
|
if self.gradient_predivide_factor != dp_world_size:
|
||
|
tensor_to_allreduce.mul_(
|
||
|
self.gradient_predivide_factor / dp_world_size)
|
||
|
else:
|
||
|
tensor_to_allreduce.div_(dp_world_size)
|
||
|
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
|
||
|
|
||
|
if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
|
||
|
tensor.copy_(tensor_to_allreduce)
|
||
|
|
||
|
return tensor
|
||
|
|
||
|
def average_tensor(self, tensor):
|
||
|
if self.overlap_comm:
|
||
|
torch.cuda.synchronize()
|
||
|
stream = self.reduction_stream
|
||
|
else:
|
||
|
stream = torch.cuda.current_stream()
|
||
|
|
||
|
with torch.cuda.stream(stream):
|
||
|
if not self.reduce_scatter:
|
||
|
self.gradient_reduction_w_predivide(tensor)
|
||
|
return
|
||
|
|
||
|
# Accumulate destination ranks and bucket offsets for each gradient slice.
|
||
|
# Note: potential future optimization, record access pattern of parameters
|
||
|
# in backward pass and partition gradients w.r.t. access pattern so that our
|
||
|
# bucket is guaranteed to be contiguous w.r.t. ranks
|
||
|
rank_and_offsets = []
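            # Each entry of rank_and_offsets is (destination partition id,
            # offset into the flat bucket, number of elements in the slice).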
|
||
|
real_dp_process_group = []
|
||
|
curr_size = 0
|
||
|
prev_id = -1
|
||
|
|
||
|
process_group = self.dp_process_group
|
||
|
# count = 0
|
||
|
for i, param, param_id in self.params_in_ipg_bucket:
|
||
|
|
||
|
process_group = self.dp_process_group
|
||
|
# Averages gradients at parameter level if ipg has a moe param
|
||
|
# Otherwise averaging is done at the entire buffer level at the end of the loop
|
||
|
if self.ipg_bucket_has_moe_params:
|
||
|
process_group = self.expert_dp_process_group if is_moe_param(
|
||
|
param) else self.dp_process_group
|
||
|
param.grad.data.div_(
|
||
|
dist.get_world_size(group=process_group))
|
||
|
|
||
|
partition_ids = self.param_to_partition_ids[i][param_id]
|
||
|
partition_size = self.partition_size[i]
|
||
|
# Get all partition ids + their offsets
|
||
|
partition_ids_w_offsets = []
|
||
|
for partition_id in partition_ids:
|
||
|
offset = self.grad_start_offset[i][partition_id][param_id]
|
||
|
partition_ids_w_offsets.append((partition_id, offset))
|
||
|
partition_ids_w_offsets.sort(key=lambda t: t[1])
|
||
|
|
||
|
# Calculate rank and offsets for grad slices
|
||
|
for idx in range(len(partition_ids_w_offsets)):
|
||
|
partition_id, offset = partition_ids_w_offsets[idx]
|
||
|
|
||
|
# if dist.get_rank() == 0 and count < 100:
|
||
|
# print(f"Rank {dist.get_rank()} rank offet id {idx} calculated dp size {dist.get_world_size(group=process_group)} real dp size {dist.get_world_size(self.real_dp_process_group[i])} and dst: {partition_id}")
|
||
|
# count += 1
|
||
|
|
||
|
# Calculate numel for grad slice depending on partition location
|
||
|
if idx == len(partition_ids_w_offsets) - 1:
|
||
|
# Last partition_id uses its own offset
|
||
|
numel = param.numel() - offset
|
||
|
else:
|
||
|
# Set numel to next partition's offset
|
||
|
numel = partition_ids_w_offsets[idx + 1][1] - offset
|
||
|
|
||
|
# Merge bucket ranges if they belong to the same rank
|
||
|
if partition_id == prev_id:
|
||
|
prev_pid, prev_size, prev_numel = rank_and_offsets[-1]
|
||
|
rank_and_offsets[-1] = (prev_pid,
|
||
|
prev_size, prev_numel + numel)
|
||
|
else:
|
||
|
rank_and_offsets.append(
|
||
|
(partition_id, curr_size, numel))
|
||
|
real_dp_process_group.append(process_group)
|
||
|
curr_size += numel
|
||
|
prev_id = partition_id
|
||
|
|
||
|
if not self.ipg_bucket_has_moe_params:
|
||
|
tensor.div_(dist.get_world_size(group=self.dp_process_group))
|
||
|
|
||
|
async_handles = []
|
||
|
for i, (dst, bucket_offset, numel) in enumerate(rank_and_offsets):
|
||
|
grad_slice = tensor.narrow(0, int(bucket_offset), int(numel))
|
||
|
# if dist.get_rank() == 0:
|
||
|
# print(f"Rank {dist.get_rank()} rank offet id {i} real dp size {dist.get_world_size(group=real_dp_process_group[i])} and dst: {dst}")
|
||
|
# dist.barrier()
|
||
|
# dist.barrier()
|
||
|
dst_rank = _get_global_rank(real_dp_process_group[i], dst)
|
||
|
async_handle = dist.reduce(grad_slice,
|
||
|
dst=dst_rank,
|
||
|
group=real_dp_process_group[i],
|
||
|
async_op=True)
|
||
|
async_handles.append(async_handle)
|
||
|
|
||
|
for handle in async_handles:
|
||
|
handle.wait()
|
||
|
|
||
|
##############################################################################
|
||
|
############################# CPU Offload Methods#############################
|
||
|
##############################################################################
|
||
|
def get_grad_position(self, group_id, tensor_list, first_offset, partition_size):
|
||
|
current_offset = 0
|
||
|
|
||
|
for i, tensor in enumerate(tensor_list):
|
||
|
param_id = self.get_param_id(tensor)
|
||
|
param_start_offset = 0
|
||
|
|
||
|
num_elements = tensor.numel()
|
||
|
tensor_offset = 0
|
||
|
|
||
|
# we need to offset to get to the right element
|
||
|
if i == 0 and first_offset > 0:
|
||
|
tensor_offset = first_offset
|
||
|
num_elements = num_elements - tensor_offset
|
||
|
param_start_offset = first_offset
|
||
|
|
||
|
# we dont need all elements of the tensor
|
||
|
if num_elements > (partition_size - current_offset):
|
||
|
num_elements = partition_size - current_offset
|
||
|
|
||
|
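            # grad_position[param_id] = [group index, offset into the param's flat grad,
            #                            offset into this rank's partition, number of elements]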
self.grad_position[param_id] = [
|
||
|
int(group_id),
|
||
|
int(param_start_offset),
|
||
|
int(current_offset),
|
||
|
int(num_elements)
|
||
|
]
|
||
|
current_offset += num_elements
|
||
|
|
||
|
def update_overflow_tracker_for_param_grad(self, param):
|
||
|
if param.grad is not None and self._has_inf_or_nan(param.grad.data):
|
||
|
self.local_overflow = True
|
||
|
|
||
|
def async_accumulate_grad_in_cpu_via_gpu(self, param):
|
||
|
param_id = self.get_param_id(param)
|
||
|
|
||
|
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
|
||
|
|
||
|
        # copy to a preexisting buffer to avoid the memory allocation penalty
|
||
|
dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow(
|
||
|
0,
|
||
|
0,
|
||
|
param.numel())
|
||
|
|
||
|
# buffer for storing gradients for this parameter in CPU
|
||
|
def buffer_to_accumulate_to_in_cpu():
|
||
|
if not self.fp16_master_weights_and_gradients:
|
||
|
return torch.zeros(param.numel(),
|
||
|
dtype=param.dtype,
|
||
|
device=self.device).pin_memory()
|
||
|
else:
|
||
|
return self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(
|
||
|
0,
|
||
|
dest_offset,
|
||
|
num_elements)
|
||
|
|
||
|
        # accumulate gradients into param.grad, or into the parts of it that belong to this partition
|
||
|
def accumulate_gradients():
|
||
|
if not self.fp16_master_weights_and_gradients:
|
||
|
dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1),
|
||
|
non_blocking=True)
|
||
|
param.grad.data.view(-1).add_(dest_buffer)
|
||
|
else:
|
||
|
dest_buffer.narrow(0,
|
||
|
source_offset,
|
||
|
num_elements).copy_(
|
||
|
self.accumulated_grads_in_cpu[param_id].view(-1),
|
||
|
non_blocking=True)
|
||
|
param.grad.data.view(-1).narrow(
|
||
|
0,
|
||
|
source_offset,
|
||
|
num_elements).add_(dest_buffer.narrow(0,
|
||
|
source_offset,
|
||
|
num_elements))
|
||
|
|
||
|
# move accumulated gradients back to CPU
|
||
|
def copy_gradients_to_cpu():
|
||
|
if not self.fp16_master_weights_and_gradients:
|
||
|
self.accumulated_grads_in_cpu[param_id].data.copy_(
|
||
|
param.grad.data.view(-1),
|
||
|
non_blocking=True)
|
||
|
else:
|
||
|
self.accumulated_grads_in_cpu[param_id].data.copy_(
|
||
|
param.grad.data.view(-1).narrow(0,
|
||
|
source_offset,
|
||
|
num_elements),
|
||
|
non_blocking=True)
|
||
|
|
||
|
if param_id not in self.accumulated_grads_in_cpu:
|
||
|
self.accumulated_grads_in_cpu[param_id] = buffer_to_accumulate_to_in_cpu(
|
||
|
)
|
||
|
|
||
|
if self.micro_step_id > 0:
|
||
|
accumulate_gradients()
|
||
|
|
||
|
# at the boundary we will send 32bit directly
|
||
|
if not self.is_gradient_accumulation_boundary:
|
||
|
copy_gradients_to_cpu()
|
||
|
|
||
|
def set_norm_for_param_grad(self, param):
|
||
|
param_id = self.get_param_id(param)
|
||
|
accumulated_grad = self.accumulated_grads_in_cpu[
|
||
|
param_id] if self.gradient_accumulation_steps > 1 else param.grad
|
||
|
|
||
|
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
|
||
|
|
||
|
start = source_offset
|
||
|
accumulated_grad = accumulated_grad.view(
|
||
|
-1).narrow(0, start, num_elements)
|
||
|
|
||
|
self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(
|
||
|
2)
|
||
|
|
||
|
def set_norm_for_param_grad_in_gpu(self, param):
|
||
|
param_id = self.get_param_id(param)
|
||
|
accumulated_grad = param.grad
|
||
|
|
||
|
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
|
||
|
|
||
|
start = source_offset
|
||
|
accumulated_grad = accumulated_grad.view(
|
||
|
-1).narrow(0, start, num_elements)
|
||
|
|
||
|
self.norm_for_param_grads[param_id] = accumulated_grad.data.double().norm(
|
||
|
2)
|
||
|
|
||
|
def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param):
|
||
|
param_id = self.get_param_id(param)
|
||
|
|
||
|
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
|
||
|
|
||
|
dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(
|
||
|
0,
|
||
|
dest_offset,
|
||
|
num_elements)
|
||
|
|
||
|
src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements)
|
||
|
if not self.fp16_master_weights_and_gradients:
|
||
|
src_tensor = src_tensor.float()
|
||
|
|
||
|
dest_tensor.copy_(src_tensor, non_blocking=True)
|
||
|
param.grad = None # offload only
|
||
|
|
||
|
def complete_grad_norm_calculation_for_cpu_offload(self, params):
|
||
|
total_norm = 0.0
|
||
|
norm_type = 2.0
|
||
|
for p in params:
|
||
|
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
|
||
|
param_id = self.get_param_id(p)
|
||
|
                # as some models have trainable parameters that are skipped in training,
|
||
|
# their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run,
|
||
|
# so they have no norm_for_param_grads
|
||
|
if param_id in self.norm_for_param_grads:
|
||
|
param_norm = self.norm_for_param_grads[param_id]
|
||
|
total_norm += param_norm.item() ** 2
|
||
|
else:
|
||
|
# As unused parameters in modules may not be expected sometimes,
|
||
|
# add an explicit error msg when it occurred and an option to
|
||
|
# avoid the error
|
||
|
assert self.ignore_unused_parameters, """
|
||
|
This assert indicates that your module has parameters that
|
||
|
were not used in producing loss.
|
||
|
You can avoid this assert by
|
||
|
            (1) enabling the ignore_unused_parameters option in zero_optimization config;
|
||
|
(2) making sure all trainable parameters and `forward` function
|
||
|
outputs participate in calculating loss.
|
||
|
"""
|
||
|
|
||
|
# Sum across all model parallel GPUs.
|
||
|
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||
|
|
||
|
torch.distributed.all_reduce(total_norm_cuda,
|
||
|
op=torch.distributed.ReduceOp.SUM,
|
||
|
group=self.dp_process_group)
|
||
|
|
||
|
self._model_parallel_all_reduce(tensor=total_norm_cuda,
|
||
|
op=torch.distributed.ReduceOp.SUM)
|
||
|
|
||
|
total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
|
||
|
|
||
|
if total_norm == float(
|
||
|
'inf') or total_norm == -float('inf') or total_norm != total_norm:
|
||
|
total_norm = -1
|
||
|
|
||
|
return total_norm
|
||
|
|
||
|
############################################################################################
|
||
|
|
||
|
def copy_grads_in_partition(self, param):
|
||
|
if self.cpu_offload:
|
||
|
|
||
|
if self.gradient_accumulation_steps > 1:
|
||
|
self.async_accumulate_grad_in_cpu_via_gpu(param)
|
||
|
|
||
|
if self.is_gradient_accumulation_boundary:
|
||
|
self.set_norm_for_param_grad_in_gpu(param)
|
||
|
|
||
|
self.update_overflow_tracker_for_param_grad(param)
|
||
|
|
||
|
self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
|
||
|
|
||
|
return
|
||
|
# print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}")
|
||
|
if self.grads_in_partition is None:
|
||
|
self.grads_in_partition_offset = 0
|
||
|
total_size = 0
|
||
|
for group in self.params_in_partition:
|
||
|
for param_in_partition in group:
|
||
|
total_size += param_in_partition.numel()
|
||
|
|
||
|
if self.verbose:
|
||
|
report_memory_usage(
|
||
|
f"before copying {total_size} gradients into partition")
|
||
|
self.grads_in_partition = torch.empty(int(total_size),
|
||
|
dtype=self.dtype,
|
||
|
device=torch.cuda.current_device())
|
||
|
|
||
|
if self.verbose:
|
||
|
report_memory_usage(
|
||
|
f"after copying {total_size} gradients into partition")
|
||
|
|
||
|
        # The allreduce buffer will be rewritten. Copy the gradients in this partition to a new buffer
|
||
|
new_grad_tensor = self.grads_in_partition.view(-1).narrow(
|
||
|
0,
|
||
|
self.grads_in_partition_offset,
|
||
|
param.numel())
|
||
|
new_grad_tensor.copy_(param.grad.view(-1))
|
||
|
param.grad.data = new_grad_tensor.data.view_as(param.grad)
|
||
|
# print(f"Grad norm after copy to contiguous_buffer {param.grad.data.norm()}")
|
||
|
self.grads_in_partition_offset += param.numel()
|
||
|
|
||
|
def reduce_ipg_grads(self):
|
||
|
if self.contiguous_gradients:
|
||
|
if self.extra_large_param_to_reduce is not None:
|
||
|
assert len(
|
||
|
self.params_in_ipg_bucket) == 1, "more than 1 param in ipg bucket, this shouldn't happen"
|
||
|
_, _, param_id = self.params_in_ipg_bucket[0]
|
||
|
assert self.get_param_id(
|
||
|
self.extra_large_param_to_reduce) == param_id, "param in ipg bucket does not match extra-large param"
|
||
|
self.average_tensor(
|
||
|
self.extra_large_param_to_reduce.grad.view(-1))
|
||
|
self.extra_large_param_to_reduce = None
|
||
|
else:
|
||
|
self.average_tensor(self.ipg_buffer[self.ipg_index])
|
||
|
else:
|
||
|
self.buffered_reduce_fallback(
|
||
|
None,
|
||
|
self.grads_in_ipg_bucket,
|
||
|
elements_per_buffer=self.elements_in_ipg_bucket)
|
||
|
|
||
|
if self.overlap_comm:
|
||
|
stream = self.reduction_stream
|
||
|
elif self.cpu_offload:
|
||
|
# TODO: copy_grad_stream is disabled because of race with reduce. This hurts perf and should be fixed.
|
||
|
# torch.cuda.synchronize()
|
||
|
# stream = self.copy_grad_stream
|
||
|
stream = torch.cuda.current_stream()
|
||
|
else:
|
||
|
stream = torch.cuda.current_stream()
|
||
|
|
||
|
with torch.cuda.stream(stream):
|
||
|
for _, param, param_id in self.params_in_ipg_bucket:
|
||
|
|
||
|
assert self.params_already_reduced[param_id] == False, \
|
||
|
f"The parameter {param_id} has already been reduced. \
|
||
|
Gradient computed twice for this partition. \
|
||
|
Multiple gradient reduction is currently not supported"
|
||
|
|
||
|
self.params_already_reduced[param_id] = True
|
||
|
|
||
|
if self.partition_gradients:
|
||
|
if not self.is_param_in_current_partition[param_id]:
|
||
|
if self.overlap_comm and self.contiguous_gradients is False:
|
||
|
# Clear grads of other partitions during the next reduction
|
||
|
# to avoid clearing them before the reduction is complete.
|
||
|
if self.previous_reduced_grads is None:
|
||
|
self.previous_reduced_grads = []
|
||
|
self.previous_reduced_grads.append(param)
|
||
|
else:
|
||
|
param.grad = None # only if self.partition_gradients
|
||
|
elif self.contiguous_gradients:
|
||
|
self.copy_grads_in_partition(param)
|
||
|
|
||
|
self.grads_in_ipg_bucket = []
|
||
|
self.params_in_ipg_bucket = []
|
||
|
self.ipg_bucket_has_moe_params = False
|
||
|
self.elements_in_ipg_bucket = 0
|
||
|
#####################################################################
|
||
|
|
||
|
def reduce_ready_partitions_and_remove_grads(self, param, i):
|
||
|
if self.partition_gradients or self.is_gradient_accumulation_boundary:
|
||
|
self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
|
||
|
|
||
|
def zero_reduced_gradients(self, partition_id, i):
|
||
|
def are_all_related_partitions_reduced(params_id):
|
||
|
for partition_id in self.param_to_partition_ids[i][params_id]:
|
||
|
if not self.is_partition_reduced[i][partition_id]:
|
||
|
return False
|
||
|
return True
|
||
|
|
||
|
for params_id in self.is_grad_computed[i][partition_id]:
|
||
|
if are_all_related_partitions_reduced(params_id):
|
||
|
self.param_dict[params_id].grad = None # dead code
|
||
|
|
||
|
def flatten_and_print(self, message, tensors, start=0, n=5):
|
||
|
flatten_tensor = self.flatten(tensors)
|
||
|
|
||
|
def print_func():
|
||
|
print(flatten_tensor.contiguous().view(-1).narrow(0, start, n))
|
||
|
|
||
|
self.sequential_execution(print_func, message)
|
||
|
|
||
|
def get_grads_to_reduce(self, i, partition_id):
|
||
|
def get_reducable_portion(key):
|
||
|
grad = self.param_dict[key].grad
|
||
|
total_elements = grad.numel()
|
||
|
start = self.grad_start_offset[i][partition_id][key]
|
||
|
num_elements = min(
|
||
|
total_elements - start,
|
||
|
self.partition_size[i] -
|
||
|
self.grad_partition_insertion_offset[i][partition_id][key])
|
||
|
if not pg_correctness_test:
|
||
|
if num_elements == total_elements:
|
||
|
return grad
|
||
|
else:
|
||
|
return grad.contiguous().view(-1).narrow(0,
|
||
|
int(start),
|
||
|
int(num_elements))
|
||
|
else:
|
||
|
if num_elements == total_elements:
|
||
|
return grad.clone()
|
||
|
else:
|
||
|
return grad.clone().contiguous().view(-1).narrow(
|
||
|
0,
|
||
|
int(start),
|
||
|
int(num_elements))
|
||
|
|
||
|
grads_to_reduce = []
|
||
|
for key in self.is_grad_computed[i][partition_id]:
|
||
|
grad = get_reducable_portion(key)
|
||
|
grads_to_reduce.append(grad)
|
||
|
return grads_to_reduce
|
||
|
|
||
|
def sequential_execution(self, function, message, group=None):
|
||
|
if group is None:
|
||
|
group = self.dp_process_group
|
||
|
if dist.get_rank(group=group) == 0:
|
||
|
print(message)
|
||
|
for id in range(dist.get_world_size(group=group)):
|
||
|
if id == dist.get_rank(group=group):
|
||
|
function()
|
||
|
dist.barrier(group=group)
|
||
|
|
||
|
def set_none_gradients_to_zero(self, i, partition_id):
|
||
|
for param_id in self.is_grad_computed[i][partition_id]:
|
||
|
param = self.param_dict[param_id]
|
||
|
if param.grad is None:
|
||
|
                param.grad = torch.zeros_like(param)
|
||
|
|
||
|
######################Reduction Related Methods##############################
|
||
|
|
||
|
def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None):
|
||
|
rank = None
|
||
|
tensor = self.flatten(bucket)
|
||
|
|
||
|
tensor_to_allreduce = tensor
|
||
|
|
||
|
if pg_correctness_test:
|
||
|
allreduce_always_fp32 = True
|
||
|
|
||
|
if allreduce_always_fp32:
|
||
|
tensor_to_allreduce = tensor.float()
|
||
|
|
||
|
tensor_to_allreduce.div_(
|
||
|
dist.get_world_size(group=self.dp_process_group))
|
||
|
|
||
|
if rank is None:
|
||
|
# "All Reducing"
|
||
|
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
|
||
|
else:
|
||
|
global_rank = _get_global_rank(self.dp_process_group, rank)
|
||
|
dist.reduce(tensor_to_allreduce, global_rank,
|
||
|
group=self.dp_process_group)
|
||
|
|
||
|
if allreduce_always_fp32 and tensor is not tensor_to_allreduce:
|
||
|
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
|
||
|
tensor.copy_(tensor_to_allreduce)
|
||
|
|
||
|
return tensor
|
||
|
|
||
|
def _clear_previous_reduced_grads(self):
|
||
|
if self.previous_reduced_grads is not None:
|
||
|
for param in self.previous_reduced_grads:
|
||
|
param.grad = None # overlap enabled
|
||
|
self.previous_reduced_grads = None
|
||
|
|
||
|
# if rank is specified do a reduction instead of an allreduce
|
||
|
def allreduce_and_copy(self, small_bucket, rank=None, log=None):
|
||
|
if self.overlap_comm:
|
||
|
torch.cuda.synchronize()
|
||
|
# It is safe to clear the previously reduced grads of other partitions
|
||
|
self._clear_previous_reduced_grads()
|
||
|
stream = self.reduction_stream
|
||
|
else:
|
||
|
stream = torch.cuda.current_stream()
|
||
|
|
||
|
with torch.cuda.stream(stream):
|
||
|
allreduced = self.allreduce_bucket(
|
||
|
small_bucket, rank=rank, log=log)
|
||
|
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
|
||
|
for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
|
||
|
buf.copy_(synced)
|
||
|
|
||
|
    def allreduce_no_retain(self, bucket, numel_per_bucket=500000000, rank=None, log=None):
        small_bucket = []
        numel = 0
        for tensor in bucket:
            small_bucket.append(tensor)
            numel = numel + tensor.numel()
            if numel > numel_per_bucket:
                self.allreduce_and_copy(small_bucket, rank=rank, log=None)
                small_bucket = []
                numel = 0  # reset the running count so the next bucket accumulates from zero

        if len(small_bucket) > 0:
            self.allreduce_and_copy(small_bucket, rank=rank, log=log)

    # allows using reduction of gradients instead of using all_reduce

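    # NOTE (editorial, illustrative only): buffered_reduce_fallback below splits grads by
    # dtype and hands each dtype bucket to allreduce_no_retain above, which flushes once
    # the running element count exceeds the threshold. With a hypothetical threshold of
    # 5 elements, tensors of sizes [3, 3, 4] would be reduced as the buckets
    # [3, 3] (6 > 5 triggers a flush) and then [4] at the end.
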
    def buffered_reduce_fallback(self, rank, grads, elements_per_buffer=500000000, log=None):
        split_buckets = split_half_float_double(grads)

        for i, bucket in enumerate(split_buckets):
            self.allreduce_no_retain(bucket,
                                     numel_per_bucket=elements_per_buffer,
                                     rank=rank,
                                     log=log)

    #############################################################################
    #############################################################################
    #############################################################################

    # views the tensor as multiple partitions and returns
    # those partitions
    def get_data_parallel_partitions(self, tensor, group_id):
        partitions = []

        dp = dist.get_world_size(group=self.real_dp_process_group[group_id])
        dp_id = dist.get_rank(group=self.real_dp_process_group[group_id])

        total_num_elements = tensor.numel()

        base_size = total_num_elements // dp
        remaining = total_num_elements % dp

        start = 0
        for id in range(dp):
            partition_size = base_size
            if id < remaining:
                partition_size = partition_size + 1
            partitions.append(tensor.narrow(0, start, partition_size))
            start = start + partition_size
        return partitions

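    # NOTE (editorial, illustrative only): the partitioning above gives the first
    # `remaining` ranks one extra element. For a hypothetical flat tensor of 10 elements
    # and dp = 4, the per-rank partition sizes come out as [3, 3, 2, 2]:
    #
    #   sizes = [10 // 4 + (1 if r < 10 % 4 else 0) for r in range(4)]  # -> [3, 3, 2, 2]
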
    def get_partition_info(self, tensor_list, partition_size, partition_id):
        params_in_partition = []
        params_not_in_partition = []

        start_index = partition_size * partition_id
        end_index = partition_size * (partition_id + 1)

        current_index = 0
        first_offset = 0

        for tensor in tensor_list:

            tensor_size = tensor.numel()

            if (current_index >= start_index and current_index < end_index):
                params_in_partition.append(tensor)

            elif start_index > current_index and start_index < (current_index + tensor_size):
                params_in_partition.append(tensor)

                assert (first_offset == 0), \
                    "This can happen either zero or only once as this must be the first tensor in the partition"
                first_offset = start_index - current_index

            else:
                params_not_in_partition.append(tensor)

            current_index = current_index + tensor_size

        return params_in_partition, params_not_in_partition, first_offset

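    # NOTE (editorial, illustrative only): a worked example of get_partition_info above
    # with hypothetical sizes. For tensors of [4, 5, 3] elements, partition_size = 6 and
    # partition_id = 1 (i.e. the element range [6, 12)):
    #   - tensor 0 covers [0, 4)  -> not in the partition
    #   - tensor 1 covers [4, 9)  -> straddles the boundary, first_offset = 6 - 4 = 2
    #   - tensor 2 covers [9, 12) -> in the partition
    # so the method returns ([t1, t2], [t0], 2).
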
    def zero_grad(self, set_grads_to_None=True):
        """
        Zero FP16 parameter grads.
        """
        # FP32 grad should never exist.
        # For speed, set model fp16 grad to None by default
        for group in self.fp16_groups:
            for p in group:
                if set_grads_to_None:
                    p.grad = None  # epilogue and in step
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()

    def _model_parallel_all_reduce(self, tensor, op):
        """ Perform all reduce within model parallel group, if any.
        """
        if self.model_parallel_group is None:
            pass
        else:
            torch.distributed.all_reduce(tensor=tensor,
                                         op=op,
                                         group=self.model_parallel_group)

    def clip_grad_norm(self, *args, **kwargs):
        # dummy function to retain the same function interface
        # as ColossalaiOptimizer for compatibility
        pass

    def get_grad_norm_direct(self, gradients, params, norm_type=2):
        """Compute the gradient norm of an iterable of parameters.

        This is adapted from ``torch.nn.utils.clip_grad.clip_grad_norm_`` and
        adds functionality to handle model parallel parameters. Unlike the
        original, it only computes the norm and does not modify the gradients.

        Arguments:
            gradients (Iterable[Tensor]): gradient tensors owned by this partition
            params (Iterable[Tensor]): the parameters corresponding to ``gradients``
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.

        Returns:
            Total norm of the gradients (viewed as a single vector), or -1 if an
            ``inf``/``NaN`` norm is detected.
        """
        norm_type = float(norm_type)
        if norm_type == inf:
            total_norm = max(g.data.abs().max() for g in gradients)
            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
            torch.distributed.all_reduce(total_norm_cuda,
                                         op=torch.distributed.ReduceOp.MAX,
                                         group=self.dp_process_group)

            # Take max across all GPUs.
            self._model_parallel_all_reduce(tensor=total_norm_cuda,
                                            op=torch.distributed.ReduceOp.MAX)
            total_norm = total_norm_cuda[0].item()
        else:
            total_norm = 0.0
            # if dist.get_rank() == 0:
            #     print(f"Total Norm beginning {total_norm}")
            for g, p in zip(gradients, params):
                if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
                    param_norm = g.data.double().norm(2)
                    total_norm += param_norm.item() ** 2
            # Sum across all model parallel GPUs.
            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])

            torch.distributed.all_reduce(total_norm_cuda,
                                         op=torch.distributed.ReduceOp.SUM,
                                         group=self.dp_process_group)

            self._model_parallel_all_reduce(tensor=total_norm_cuda,
                                            op=torch.distributed.ReduceOp.SUM)

            total_norm = total_norm_cuda[0].item() ** (1. / norm_type)

        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
            total_norm = -1

        return total_norm

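    # NOTE (editorial, illustrative only): callers collect one such norm per parameter
    # group and combine them later as
    #
    #   total = math.sqrt(sum(n ** 2 for n in norm_groups))
    #
    # (see unscale_and_clip_grads further below). A returned value of -1 marks an
    # inf/NaN norm and leads to the step being skipped.
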
    # creates a flat fused tensor from the tensor list starting at the first_offset
    # in the first tensor of the list. If there are not enough elements in the tensor
    # list then the flat tensor will be padded with zeros
    def get_flat_partition(self,
                           tensor_list,
                           first_offset,
                           partition_size,
                           dtype,
                           device,
                           return_tensor_list=False):
        flat_tensor_list = []
        current_size = 0
        for i, tensor in enumerate(tensor_list):
            if tensor.grad is None:
                tensor.grad = torch.zeros_like(tensor)

            tensor = tensor.grad
            num_elements = tensor.numel()
            tensor_offset = 0

            # we need to offset to get to the right element
            if i == 0 and first_offset > 0:
                tensor_offset = first_offset
                num_elements = num_elements - tensor_offset

            # we don't need all elements of the tensor
            if num_elements > (partition_size - current_size):
                num_elements = partition_size - current_size

            # we need a narrow view of the tensor based on the tensor offset and the
            # number of elements that we need from this tensor
            if tensor_offset > 0 or num_elements < tensor.numel():
                flat_tensor_list.append(tensor.contiguous().view(-1).narrow(
                    0, int(tensor_offset), int(num_elements)))
            else:
                flat_tensor_list.append(tensor)

            current_size = current_size + num_elements

        # this means it is the last partition and does not align with the dp boundary.
        # We need to pad before flattening
        if current_size < partition_size:
            flat_tensor_list.append(
                torch.zeros(int(partition_size - current_size), dtype=dtype, device=device))

        if return_tensor_list:
            return flat_tensor_list

        return self.flatten(flat_tensor_list)

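    # NOTE (editorial, illustrative only): with a hypothetical partition_size of 10 and
    # gradients contributing only 7 elements, get_flat_partition above appends a zero
    # tensor of 3 elements so that the flattened partition always has exactly
    # partition_size elements:
    #
    #   padding = torch.zeros(10 - 7, dtype=dtype, device=device)
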
    def free_grad_in_param_list(self, param_list):
        for p in param_list:
            p.grad = None  # in step

    def reset_cpu_buffers(self):
        self.norm_for_param_grads = {}
        self.local_overflow = False

    def log_timers(self, timer_names):
        if self.timers is None:
            return

        self.timers.log(names=list(timer_names))

    def start_timers(self, timer_names):
        if self.timers is None:
            return

        for name in timer_names:
            self.timers(name).start()

    def stop_timers(self, timer_names):
        if self.timers is None:
            return

        for name in timer_names:
            self.timers(name).stop()

    def step(self, closure=None):
        """
        Not supporting closure.
        """
        self.micro_step_id = -1

        if self.verbose:
            report_memory_usage("In step before checking overflow")

        # First compute norm for all groups so we know if there is overflow
        self.check_overflow(self.partition_gradients)

        OPTIMIZER_ALLGATHER = 'optimizer_allgather'
        OPTIMIZER_GRADIENTS = 'optimizer_gradients'
        OPTIMIZER_STEP = 'optimizer_step'
        timer_names = [OPTIMIZER_ALLGATHER, OPTIMIZER_GRADIENTS, OPTIMIZER_STEP]

        prev_scale = self.loss_scale
        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                report_memory_usage('After overflow before clearing gradients')
            self.zero_grad()
            if self.cpu_offload:
                self.reset_cpu_buffers()
            else:
                self.averaged_gradients = {}

            if self.verbose:
                report_memory_usage('After overflow after clearing gradients')

            print(
                "[deepspeed] fp16 dynamic loss scale overflow! Rank {} Skipping step. Attempted loss scale: {}, "
                "reducing to {}".format(dist.get_rank(), prev_scale, self.loss_scale))
            self.start_timers(timer_names)
            self.stop_timers(timer_names)
            return

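        # NOTE (editorial, illustrative only): on overflow the step is skipped entirely and
        # the dynamic loss scaler lowers the scale for the next iteration. The exact update
        # depends on the LossScaler configuration; a typical backoff looks like
        #
        #   new_scale = prev_scale / backoff_factor   # e.g. 65536 / 2 -> 32768
        #
        # where `backoff_factor` is a hypothetical name for the scaler's reduction ratio.
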
        self.start_timers([OPTIMIZER_GRADIENTS])
        norm_groups = []
        single_partition_grad_groups = []
        skip = False
        for i, group in enumerate(self.fp16_groups):
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            if self.cpu_offload:
                norm_groups.append(
                    self.complete_grad_norm_calculation_for_cpu_offload(
                        self.params_in_partition[i]))
                single_grad_partition = self.single_partition_of_fp32_groups[i].grad
            else:
                norm_groups.append(
                    self.get_grad_norm_direct(self.averaged_gradients[i],
                                              self.params_in_partition[i]))

                # free gradients for all the parameters that are not updated by this process
                self.free_grad_in_param_list(self.params_not_in_partition[i])

                # create a flat gradient for the parameters updated by this process.
                # If we are the last partition, ensure the flattened grads match the
                # partition size; if not, pad with zero tensors
                if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
                    single_grad_partition = self.flatten_dense_tensors_aligned(
                        self.averaged_gradients[i],
                        int(self.partition_size[i])).to(
                            self.single_partition_of_fp32_groups[i].dtype)
                else:
                    single_grad_partition = self.flatten(self.averaged_gradients[i]).to(
                        self.single_partition_of_fp32_groups[i].dtype)
                assert single_grad_partition.numel() == self.partition_size[i], \
                    "averaged gradients have a different number of elements than the partition size {} {} {} {}".format(
                        single_grad_partition.numel(), self.partition_size[i], i, partition_id)

                self.single_partition_of_fp32_groups[i].grad = single_grad_partition
                # release all the gradients since we have already created a necessary copy in dp_grad_partition
                self.free_grad_in_param_list(self.params_in_partition[i])

                self.averaged_gradients[i] = None

            single_partition_grad_groups.append(single_grad_partition)

        if self.has_moe_layers:
            self._average_expert_grad_norms(norm_groups)

        self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
        self.stop_timers([OPTIMIZER_GRADIENTS])

        self.start_timers([OPTIMIZER_STEP])
        if self.deepspeed_adam_offload:
            from deepspeed.ops.adam import DeepSpeedCPUAdam
            if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half:
                fp16_param_groups = [
                    fp16_partitions[partition_id]
                    for fp16_partitions in self.parallel_partitioned_fp16_groups
                ]
                self.optimizer.step(fp16_param_groups=fp16_param_groups)
            else:
                self.optimizer.step()
                for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups,
                                                           self.single_partition_of_fp32_groups):
                    fp16_partitions[partition_id].data.copy_(fp32_partition.data)
        else:
            self.optimizer.step()

            # get rid of the fp32 gradients. Not needed anymore
            if not self.cpu_offload:
                for group in self.single_partition_of_fp32_groups:
                    group.grad = None  # in step

            for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups,
                                                       self.single_partition_of_fp32_groups):
                fp16_partitions[partition_id].data.copy_(fp32_partition.data)

        self.stop_timers([OPTIMIZER_STEP])

        if self.cpu_offload:
            self.reset_cpu_buffers()

        self.start_timers([OPTIMIZER_ALLGATHER])
        # gather the updated weights from everyone
        for group_id, partitioned_params in enumerate(self.parallel_partitioned_fp16_groups):

            # Sequential AllGather Best of both worlds
            dp_world_size = dist.get_world_size(group=self.real_dp_process_group[group_id])
            num_shards = max(
                1,
                partitioned_params[partition_id].numel() * dp_world_size //
                self.allgather_bucket_size)

            shard_size = partitioned_params[partition_id].numel() // num_shards
            num_elements = shard_size

            assert shard_size * num_shards <= partitioned_params[partition_id].numel()

            for shard_id in range(num_shards):

                if shard_id == (num_shards - 1):
                    num_elements = partitioned_params[partition_id].numel() - shard_id * shard_size

                shard_list = []
                for dp_id in range(dp_world_size):
                    curr_shard = partitioned_params[dp_id].narrow(
                        0, shard_id * shard_size, num_elements).detach()
                    shard_list.append(curr_shard)

                dist.all_gather(shard_list,
                                shard_list[partition_id],
                                group=self.real_dp_process_group[group_id])
        self.stop_timers([OPTIMIZER_ALLGATHER])

        # TODO: we probably don't need this? just to be safe
        for i in range(len(norm_groups)):
            self._update_model_fp16_weights(i)

        self.log_timers(timer_names)
        if self.verbose:
            report_memory_usage('After zero_optimizer step')

        return

    def _average_expert_grad_norms(self, norm_groups):
        for i, norm in enumerate(norm_groups):
            if self.is_moe_param_group[i]:
                scaled_norm = norm * 1.0 / float(
                    dist.get_world_size(group=self.ep_process_group))
                scaled_norm_tensor = torch.tensor(scaled_norm,
                                                  device='cuda',
                                                  dtype=torch.float)
                dist.all_reduce(scaled_norm_tensor, group=self.ep_process_group)
                norm_groups[i] = scaled_norm_tensor.item()

    def unscale_and_clip_grads(self, grad_groups_flat, norm_groups):
        total_norm = 0.0
        for norm in norm_groups:
            total_norm += norm ** 2.0
        total_norm = math.sqrt(total_norm)

        # compute combined scale factor for this group
        combined_scale = self.loss_scale
        if self.clip_grad > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
            if clip > 1:
                combined_scale = clip * self.loss_scale

        for grad in grad_groups_flat:
            if isinstance(grad, list):
                sub_partitions = grad
                for g in sub_partitions:
                    g.data.mul_(1. / combined_scale)
            else:
                grad.data.mul_(1. / combined_scale)

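    # NOTE (editorial, illustrative only): a worked example of the combined scale above,
    # with hypothetical values loss_scale = 1024, clip_grad = 1.0 and a scaled
    # total_norm = 2048 (i.e. an unscaled norm of 2):
    #
    #   clip = (2048 / 1024 + 1e-6) / 1.0      # ~2.0 > 1, so clipping kicks in
    #   combined_scale = clip * 1024           # ~2048
    #   grad.mul_(1. / combined_scale)         # removes the loss scale and clips to norm ~1
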
    def _check_overflow(self, partition_gradients=True):
        self.overflow = self.has_overflow(partition_gradients)

    # `params` is a list / generator of torch.Variable
    def has_overflow_serial(self, params, is_grad_list=False):
        for p in params:
            if p.grad is not None and self._has_inf_or_nan(p.grad.data):
                return True

        return False

    def has_overflow_partitioned_grads_serial(self):
        for i in range(len(self.fp16_groups)):
            for j, grad in enumerate(self.averaged_gradients[i]):
                if grad is not None and self._has_inf_or_nan(grad.data, j):
                    return True
        return False

    def has_overflow(self, partition_gradients=True):
        if partition_gradients:
            overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial()
            overflow_gpu = torch.cuda.ByteTensor([overflow])
            # This will capture overflow across all data parallel and expert parallel processes,
            # since the expert parallel processes are a subset of the data parallel processes.
            torch.distributed.all_reduce(overflow_gpu,
                                         op=torch.distributed.ReduceOp.MAX,
                                         group=self.dp_process_group)

        else:
            params = []
            for group in self.fp16_groups:
                for param in group:
                    params.append(param)

            overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients)
            overflow_gpu = torch.cuda.ByteTensor([overflow])

        # Since each model parallel GPU carries only part of the model,
        # make sure overflow flag is synced across all the model parallel GPUs
        self._model_parallel_all_reduce(tensor=overflow_gpu,
                                        op=torch.distributed.ReduceOp.MAX)

        overflow = overflow_gpu[0].item()
        return bool(overflow)

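    # NOTE (editorial, illustrative only): the MAX all-reduce over a 0/1 ByteTensor above
    # acts as a distributed logical OR: if any rank saw an inf/NaN, every rank ends up
    # with 1. A minimal sketch of the same idea (group is hypothetical):
    #
    #   flag = torch.cuda.ByteTensor([local_overflow])
    #   dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=group)
    #   overflow_everywhere = bool(flag.item())
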
    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x, j=None):
        try:
            # if x is half, the .float() incurs an additional deep copy, but it's necessary if
            # Pytorch's .sum() creates a one-element tensor of the same type as x
            # (which is true for some recent version of pytorch).
            cpu_sum = float(x.float().sum())
            # More efficient version that can be used if .sum() returns a Python scalar
            # cpu_sum = float(x.sum())
        except RuntimeError as instance:
            # We want to check if `instance` is actually an overflow exception.
            # RuntimeError could come from a different error.
            # If so, we still want the exception to propagate.
            if "value cannot be converted" not in instance.args[0]:
                raise
            return True
        else:
            if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
                return True
            return False

    def backward(self, loss, retain_graph=False):
        """
        :attr:`backward` performs the following steps:

        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
        """
        self.micro_step_id += 1

        if self.contiguous_gradients:
            self.ipg_buffer = []
            buf_0 = torch.empty(int(self.reduce_bucket_size),
                                dtype=self.dtype,
                                device=torch.cuda.current_device())
            self.ipg_buffer.append(buf_0)

            # Use double buffers to avoid data access conflict when overlap_comm is enabled.
            if self.overlap_comm:
                buf_1 = torch.empty(int(self.reduce_bucket_size),
                                    dtype=self.dtype,
                                    device=torch.cuda.current_device())
                self.ipg_buffer.append(buf_1)
            self.ipg_index = 0

        self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)

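    # NOTE (editorial, illustrative only): a typical training iteration with this optimizer
    # calls backward() on the loss and then step(), e.g.:
    #
    #   loss = model(batch)          # hypothetical model/batch
    #   optimizer.backward(loss)     # scales the loss and accumulates fp16 grads
    #   optimizer.step()             # unscales, clips, updates, and all-gathers weights
    #
    # This mirrors the backward()/step() methods defined in this class; the surrounding
    # training loop is an assumption, not part of this file.
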
    def check_overflow(self, partition_gradients=True):
        self._check_overflow(partition_gradients)

    def _update_scale(self, has_overflow=False):
        self.loss_scaler.update_scale(has_overflow)

    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
    def _get_loss_scale(self):
        return self.loss_scaler.loss_scale

    def _set_loss_scale(self, value):
        self.loss_scaler.cur_scale = value

    loss_scale = property(_get_loss_scale, _set_loss_scale)
    cur_scale = property(_get_loss_scale, _set_loss_scale)

    # Return group tensor after removing paddings that are added for alignment to DP world size.
    # This method works on the assumption that each group contains a single flattened tensor.
    def _get_groups_without_padding(self, groups_with_padding):
        groups_without_padding = []
        for i, group in enumerate(groups_with_padding):
            lean_length = group.numel() - self.groups_padding[i]
            groups_without_padding.append(group[:lean_length])

        return groups_without_padding

    # Return optimizer state after removing paddings that are added for alignment.
    def _get_state_without_padding(self, state_with_padding, padding):
        lean_state = {}
        for key, value in state_with_padding.items():
            if torch.is_tensor(value):
                lean_length = value.numel() - padding
                lean_state[key] = value[:lean_length]
            else:
                lean_state[key] = value

        return lean_state

    # Return base optimizer states.
    # This method assumes that each param group contains a single flattened tensor.
    def _get_base_optimizer_state(self):
        optimizer_groups_state = []
        for i, group in enumerate(self.optimizer.param_groups):
            p = group['params'][0]
            lean_optimizer_state = self._get_state_without_padding(
                self.optimizer.state[p], self.groups_padding[i])
            optimizer_groups_state.append(lean_optimizer_state)

        return optimizer_groups_state

    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.

        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['loss_scaler'] = self.loss_scaler
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['overflow'] = self.overflow
        state_dict['base_optimizer_state'] = self._get_base_optimizer_state()

        state_dict['zero_stage'] = ZERO_OPTIMIZATION_GRADIENTS
        state_dict['partition_count'] = self.partition_count

        state_dict['ds_version'] = version

        # Remove paddings for DP alignment to enable loading for other alignment values
        fp32_groups_without_padding = self._get_groups_without_padding(
            self.single_partition_of_fp32_groups)
        state_dict['single_partition_of_fp32_groups'] = fp32_groups_without_padding

        # if self.cpu_offload:
        #     state_dict_tmp = async_copy_to(state_dict,
        #                                    'cpu',
        #                                    torch.cuda.current_stream())
        #     state_dict = state_dict_tmp

        return state_dict

    # Restore base optimizer fp32 weights from checkpoint by:
    # 1) Merging fp32 weights from checkpoints of all partitions
    # 2) Extracting fp32 weights for current partition from merged weights
    # 3) Using extracted weights to update base optimizer weights directly.
    def _restore_from_fp32_weights(self, all_state_dict):
        merged_single_partition_of_fp32_groups = []
        for i in range(len(self.single_partition_of_fp32_groups)):
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            merged_partitions = [
                sd['single_partition_of_fp32_groups'][i] for sd in all_state_dict
            ]
            flat_merged_partitions = self.flatten_dense_tensors_aligned(
                merged_partitions,
                self.nccl_start_alignment_factor *
                dist.get_world_size(group=self.real_dp_process_group[i]))
            dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, i)
            merged_single_partition_of_fp32_groups.append(dp_partitions[partition_id])

        for current, saved in zip(self.single_partition_of_fp32_groups, merged_single_partition_of_fp32_groups):
            current.data.copy_(saved.data)

    # Restore base optimizer fp32 weights from ZeRO fp16 weights
    def _restore_from_fp16_weights(self):
        for group_id, (fp16_partitions, fp32_partition) in enumerate(
                zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups)):
            partition_id = dist.get_rank(group=self.real_dp_process_group[group_id])
            fp32_partition.data.copy_(fp16_partitions[partition_id].data)

    # Refresh the fp32 master params from the fp16 copies.
    def refresh_fp32_params(self):
        self._restore_from_fp16_weights()

    # Extract optimizer state for current partition from merged states of all partitions
    def _partition_base_optimizer_state(self, state_key, all_partition_states, group_id):
        partition_id = dist.get_rank(group=self.real_dp_process_group[group_id])
        alignment = dist.get_world_size(group=self.real_dp_process_group[group_id])
        if torch.is_tensor(all_partition_states[0]):
            flat_merged_partitions = self.flatten_dense_tensors_aligned(
                all_partition_states, alignment)
            dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, group_id)
            return dp_partitions[partition_id]
        else:
            # Assume non-tensor states are not partitioned and equal across ranks, so return first one
            return all_partition_states[0]

    # Restore base optimizer state from checkpoint by
    # 1) Merging optimizer state from checkpoints of all partitions
    # 2) Extracting optimizer state for current partition from the merged state
    # 3) Using the extracted value to directly update the base optimizer.
    def _restore_base_optimizer_state(self, all_state_dict):
        base_optimizer_group_states = []
        for i in range(len(self.optimizer.param_groups)):
            partition_states = {}
            all_partition_group_states = [
                sd['base_optimizer_state'][i] for sd in all_state_dict
            ]
            for key in all_partition_group_states[0].keys():
                all_partition_states = [
                    all_states[key] for all_states in all_partition_group_states
                ]
                partition_states[key] = self._partition_base_optimizer_state(
                    key, all_partition_states, i)
            base_optimizer_group_states.append(partition_states)

        for i, group in enumerate(self.optimizer.param_groups):
            p = group['params'][0]
            for key, saved in base_optimizer_group_states[i].items():
                if torch.is_tensor(self.optimizer.state[p][key]):
                    self.optimizer.state[p][key].data.copy_(saved.data)
                else:
                    self.optimizer.state[p][key] = saved

    def load_state_dict(self,
                        state_dict_list,
                        load_optimizer_states=True,
                        load_from_fp32_weights=False):
        r"""Loading ZeRO checkpoint

        Arguments:
            state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition.
                Note that the number of saved partitions may differ from the number of loading partitions to support
                changing GPU count, specifically DP world size, between saving and loading checkpoints.
            load_optimizer_states: Boolean indicating whether or not to load base optimizer states
            load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32
                copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss).

        Loads a state_dict created by an earlier call to state_dict().
        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.

        Example::

            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # I think it should actually be ok to reload the optimizer before the model.
        self.loss_scaler = state_dict_list[0]['loss_scaler']
        self.dynamic_loss_scale = state_dict_list[0]['dynamic_loss_scale']
        self.overflow = state_dict_list[0]['overflow']

        # zero stage 1 mode
        if not self.partition_gradients:
            required_version = pkg_version.parse("0.3.17")
            ckpt_version = state_dict_list[0].get("ds_version", False)
            error_str = f"ZeRO stage 1 changed in {required_version} and is not backwards compatible " \
                "with older stage 1 checkpoints. If you'd like to load an old ZeRO-1 checkpoint " \
                "please set 'legacy_stage1': true in your zero config json. This old version of " \
                "stage 1 will be removed in v0.4.0."

            assert ckpt_version, f"Empty ds_version! {error_str}"
            assert required_version <= pkg_version.parse(ckpt_version), \
                f"Old version: {ckpt_version} {error_str}"

        if load_optimizer_states:
            self._restore_base_optimizer_state(state_dict_list)

        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
        # The optimizer's hyperparameters and internal buffers are also up to date.
        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
        # out of date. There are two options.
        # 1: Refresh the master params from the model's fp16 params.
        #    This requires less storage but incurs precision loss.
        # 2: Save and restore the fp32 master copies separately.
        #    We choose option 1 if changing DP degree and option 2 otherwise.
        #
        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
        # of their associated parameters, because it's possible those buffers might not exist yet in
        # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
        # constructed in the same way as the one whose state_dict we are loading, the same master params
        # are guaranteed to exist, so we can just copy_() from the saved master params.

        if load_from_fp32_weights:
            self._restore_from_fp32_weights(state_dict_list)
        else:
            self._restore_from_fp16_weights()

    def allreduce_gradients(self):
        self.overlapping_partition_gradients_reduce_epilogue()

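# NOTE (editorial, illustrative only): load_state_dict above expects one state dict per
# *saved* partition. A minimal sketch of gathering them from per-rank checkpoint files
# (the file naming scheme here is hypothetical, not something this module defines):
#
#   state_dict_list = [torch.load(f"zero_ckpt_rank{r}.pth") for r in range(saved_dp_world_size)]
#   optimizer.load_state_dict(state_dict_list, load_optimizer_states=True)
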
def _handle_overflow(cpu_sum, x, i):
    import math
    rank = torch.distributed.get_rank()
    if rank == 0:
        t_i = -1
        for v_i, v in enumerate(x.data.contiguous().view(-1)):
            if not math.isfinite(float(v)):
                t_i = v_i
                break
        print(
            f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
        )


def estimate_zero2_model_states_mem_needs(total_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5):
    total_gpus = num_nodes * num_gpus_per_node

    if cpu_offload:
        gpu_mem = 2 * total_params
        cpu_mem = total_params * max(4 * total_gpus, 16) * additional_buffer_factor
    else:
        gpu_mem = 4 * total_params + int(16 * total_params / total_gpus)
        cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem)

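# NOTE (editorial, illustrative only): a worked example of the estimate above for a
# hypothetical 1e9-parameter model on 1 node with 8 GPUs and cpu_offload=True:
#
#   gpu_mem = 2 * 1e9 bytes                      ~ 1.9 GB per GPU (fp16 params)
#   cpu_mem = 1e9 * max(4 * 8, 16) * 1.5 bytes   ~ 44.7 GB of host memory
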
def model_to_params(model):
    # shared params calculated only once
    total_params = sum(
        dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
    return total_params


def estimate_zero2_model_states_mem_needs_all_live(model,
                                                   num_gpus_per_node=1,
                                                   num_nodes=1,
                                                   additional_buffer_factor=1.5):
    """
    Print out estimates on memory usage requirements for ZeRO 2 params, optim states and gradients
    for a given ``model`` and hardware setup.

    If you have an actual model object, use this function and everything will be derived
    automatically.

    If it's a hypothetical model, use ``estimate_zero2_model_states_mem_needs_all_cold`` where you have to pass
    the ``total_params`` explicitly.

    Args:
        - ``model``: ``nn.Module`` object
        - ``num_gpus_per_node``: how many gpus per node (defaults to 1)
        - ``num_nodes``: how many nodes (defaults to 1),
        - ``additional_buffer_factor``: estimation factor (defaults to 1.5):
    """

    total_params = model_to_params(model)

    estimate_zero2_model_states_mem_needs_all_cold(
        total_params=total_params,
        num_gpus_per_node=num_gpus_per_node,
        num_nodes=num_nodes,
        additional_buffer_factor=additional_buffer_factor)


def estimate_zero2_model_states_mem_needs_all_cold(total_params,
                                                   num_gpus_per_node=1,
                                                   num_nodes=1,
                                                   additional_buffer_factor=1.5):
    """
    Print out estimates on memory usage requirements for ZeRO 2 params, optim states and gradients
    for a given ``model`` and hardware setup.

    If it's a hypothetical model, use this function where you have to pass
    the ``total_params`` explicitly.

    If you have an actual model object, use ``estimate_zero2_model_states_mem_needs_all_live`` and everything
    will be derived automatically.

    Args:
        - ``total_params``: total model params
        - ``num_gpus_per_node``: how many gpus per node (defaults to 1)
        - ``num_nodes``: how many nodes (defaults to 1),
        - ``additional_buffer_factor``: estimation factor (defaults to 1.5):
    """

    def format_options(cpu_offload):
        enabled = []
        enabled.append(f"cpu_offload={1 if cpu_offload else 0}")
        return ", ".join(enabled)

    nodes_str = "nodes" if num_nodes > 1 else "node"
    gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU"
    print(
        "Estimated memory needed for params, optim states and gradients for a:\n"
        f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n"
        f"SW: Model with {int(total_params / 1e6)}M total params.")
    print("  per CPU  |  per GPU |   Options")
    for cpu_offload in [True, False]:
        cpu_mem, gpu_mem = estimate_zero2_model_states_mem_needs(
            total_params=total_params,
            num_gpus_per_node=num_gpus_per_node,
            num_nodes=num_nodes,
            cpu_offload=cpu_offload,
            additional_buffer_factor=additional_buffer_factor)

        options_str = format_options(cpu_offload=cpu_offload)
        print(f"  {cpu_mem / 2 ** 30:7.2f}GB | {gpu_mem / 2 ** 30:6.2f}GB | {options_str}")
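

# NOTE (editorial, illustrative only): a hedged usage sketch of the cold estimator above,
# with made-up numbers (a 1e9-parameter model on a single 8-GPU node):
#
#   estimate_zero2_model_states_mem_needs_all_cold(total_params=int(1e9),
#                                                  num_gpus_per_node=8,
#                                                  num_nodes=1)
#
# which prints one "per CPU | per GPU" line for cpu_offload=True and one for False.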