mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
818 lines
35 KiB
818 lines
35 KiB
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
|
|
import copy
|
|
import math
|
|
import warnings
|
|
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from packaging.version import Version
|
|
from torch.nn import Parameter
|
|
from torch.optim import Optimizer
|
|
from torch.distributed import ProcessGroup
|
|
|
|
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
|
|
from colossalai.checkpoint_io.utils import StateDictSharder
|
|
from colossalai.interface import OptimizerWrapper
|
|
from colossalai.logging import get_dist_logger
|
|
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
|
|
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
|
|
|
|
from .chunk import Chunk, ChunkManager
|
|
from .gemini_ddp import GeminiDDP
|
|
from colossalai.checkpoint_io.utils import gather_distributed_param
|
|
from colossalai.tensor.d_tensor import (
|
|
distribute_tensor,
|
|
distribute_tensor_with_customization,
|
|
init_tensor_as_customization_distributed,
|
|
get_device_mesh,
|
|
get_sharding_spec,
|
|
is_customized_distributed_tensor,
|
|
is_distributed_tensor,
|
|
get_global_shape,
|
|
init_as_dtensor
|
|
)
|
|
|
|
__all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"]
|
|
|
|
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
|
|
|
|
|
|
class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
    """FP16 mixed-precision helper whose overflow status is read from a ``GeminiDDP`` module."""

    def __init__(
        self,
        module: GeminiDDP,
        initial_scale: float = 2**16,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
    ) -> None:
        # The scaling settings are forwarded positionally, in the order received.
        scaler_settings = (
            initial_scale,
            min_scale,
            growth_factor,
            backoff_factor,
            growth_interval,
            hysteresis,
            max_scale,
        )
        super().__init__(*scaler_settings)
        self.module = module

    def check_local_overflow(self) -> bool:
        """Return True when the wrapped module's ``overflow_counter`` is positive."""
        overflow_count = self.module.overflow_counter
        return overflow_count > 0

    def pre_zero_grad(self) -> None:
        """Reset the wrapped module's ``overflow_counter`` to zero."""
        self.module.overflow_counter = 0
|
|
|
|
|
|
class GeminiOptimizer(OptimizerWrapper):
|
|
"""A wrapper for optimizer. ``GeminiDDP`` and ``GeminiOptimizer`` implement Zero Redundancy Optimizer (ZeRO stage-3).
|
|
|
|
Note:
|
|
You must use ``GeminiDDP`` with ``GeminiOptimizer``.
|
|
|
|
Note:
|
|
Make sure you set ``placement_policy`` of ``GeminiManager`` to `"auto"`,
|
|
if you set ``gpu_margin_mem_ratio > 0``.
|
|
|
|
Args:
|
|
optim (Optimizer): An Optimizer instance.
|
|
module (GeminiDDP): A ``GeminiDDP`` instance.
|
|
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
|
|
which will be used when using hybrid CPU optimizer.
|
|
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
|
|
Defaults to 0.0.
|
|
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
|
|
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
|
|
growth_factor (float, optional): Growth_factor used by DynamicGradScaler. Defaults to 2.
|
|
backoff_factor (float, optional): Backoff_factor used by DynamicGradScaler. Defaults to 0.5.
|
|
growth_interval (int, optional): Growth_interval used by DynamicGradScaler. Defaults to 1000.
|
|
hysteresis (int, optional): Hysteresis used by DynamicGradScaler. Defaults to 2.
|
|
max_scale (float, optional): Max_scale used by DynamicGradScaler. Defaults to 2**32.
|
|
max_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0.
|
|
norm_type (float, optional): The type of norm used for gradient clipping. Currently, only L2-norm (norm_type=2.0)
|
|
is supported in GeminiOptimizer. Defaults to 2.0.
|
|
verbose (bool, optional): Whether to print verbose information, including grad overflow info. Defaults to False.
|
|
"""
|
|
|
|
def __init__(
    self,
    optim: Optimizer,
    module: GeminiDDP,
    gpu_margin_mem_ratio: float = 0.0,
    initial_scale: float = 2**32,
    min_scale: float = 1,
    growth_factor: float = 2,
    backoff_factor: float = 0.5,
    growth_interval: int = 1000,
    hysteresis: int = 2,
    max_scale: float = 2**32,
    max_norm: float = 0.0,
    norm_type: float = 2.0,
    tp_group: ProcessGroup = None,
    optimizer_params_info=None,
    verbose: bool = False,
    **defaults: Any,
):
    """Wrap ``optim`` so that it operates on the chunks managed by ``module``.

    Besides the arguments described on the class:

    Args:
        tp_group (ProcessGroup, optional): When given, ``tp_size`` and ``tp_rank`` are
            queried from it; otherwise they default to 1 and 0.
        optimizer_params_info: Stored as-is. Elsewhere in this class it is indexed as
            ``optimizer_params_info["id2shape"][param_id]``.
        **defaults: Accepted but not referenced in this method body.

    Raises:
        AssertionError: If ``module`` is not a ``GeminiDDP``, ``optim`` is not one of
            ``_AVAIL_OPTIM_LIST``, clipping is enabled with ``norm_type != 2.0``, or
            ``gpu_margin_mem_ratio`` is outside ``[0.0, 1.0]``.
        RuntimeError: If ``module.mixed_precision`` is neither fp16 nor bf16.
    """
    super().__init__(optim)
    assert isinstance(module, GeminiDDP)
    # Exact type match: subclasses of the listed optimizers are rejected as well.
    assert type(optim) in _AVAIL_OPTIM_LIST, (
        "You should use an optimizer in the available list:\n" f"{_AVAIL_OPTIM_LIST}"
    )
    self.module = module
    self.gemini_manager = module.gemini_manager
    self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
    # Both mappings are keyed by the fake parameters created in ``__init__optimizer``.
    self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
    self.param_to_chunk16: Dict[Parameter, Chunk] = dict()
    self.chunk16_set: Set[Chunk] = set()
    # Gradient clipping is enabled only for a strictly positive ``max_norm``.
    self.clipping_flag = max_norm > 0.0
    self.max_norm = max_norm
    self.tp_group = tp_group
    self.optimizer_params_info = optimizer_params_info
    self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
    self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
    self.verbose = verbose
    # Filled by ``__init__optimizer`` with one entry per param group, where
    # "params" holds integer ids instead of tensors.
    self.param_groups_backup = list()

    # Mapping from integer id to real/fake param tensor, used for checkpointing.
    self.id_to_real_params: Dict[int, Parameter] = dict()
    self.id_to_fake_params: Dict[int, Parameter] = dict()

    if self.clipping_flag:
        assert norm_type == 2.0, "GeminiOptimizer only supports L2 norm now"

    # Collect the parameters that are not ignored by DDP; ignored ones that still
    # require gradients only trigger a warning.
    ddp_param_list = []
    for name, param in module.named_parameters():
        if is_ddp_ignored(param):
            if param.requires_grad:
                warnings.warn(
                    f"Parameter `{name}` is ignored by DDP but requires gradient! "
                    "You should handle its optimizer update by yourself!"
                )
        else:
            ddp_param_list.append(param)

    # Register each distinct chunk once and tell it whether an L2 norm is needed.
    for p in ddp_param_list:
        chunk_16 = self.chunk_manager.get_chunk(p)
        if chunk_16 not in self.chunk16_set:
            chunk_16.l2_norm_flag = self.clipping_flag
            self.chunk16_set.add(chunk_16)

    # Replaces the params of the inner optimizer's groups with fake parameters.
    self.__init__optimizer()

    # Pick the mixed-precision helper from the module's configured dtype.
    if module.mixed_precision is torch.float16:
        self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin(
            module,
            initial_scale=initial_scale,
            min_scale=min_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            max_scale=max_scale,
        )
    elif module.mixed_precision is torch.bfloat16:
        self.mix_precision_mixin = BF16MixedPrecisionMixin()
    else:
        raise RuntimeError(f"Unsupported mixed precision type: {module.mixed_precision}")

    self._logger = get_dist_logger()

    self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
    assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f"gpu_margin_mem_ratio must >=0.0 and <=1.0"
    # Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid
    # Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
    # and it must set `num_fp32_shards_per_param` correctly
    self._should_move_fp32_params_h2d: bool = (
        self.gemini_manager.is_cuda_margin_mem_avail
        and self.gpu_margin_mem_ratio > 0.0
        and getattr(optim, "num_fp32_shards_per_param", 0) >= 2
    )
    if self.gpu_margin_mem_ratio > 0.0 and not self.gemini_manager.is_cuda_margin_mem_avail:
        self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0])

    # NOTE(review): ``disposable`` presumably makes the wrapped callable run only once —
    # confirm against colossalai.utils.
    self._register_states = disposable(self._register_states_)
|
|
|
|
def _set_grad_ptr(self):
|
|
for group in self.param_groups:
|
|
for fake_param in group["params"]:
|
|
chunk16 = self.param_to_chunk16[fake_param]
|
|
begin, end = self.param_to_range[fake_param]
|
|
|
|
grad_chunk16 = chunk16 if self.module.reuse_fp16_chunk else chunk16.grad_chunk
|
|
fake_param.data = grad_chunk16.payload[begin:end]
|
|
fake_param.grad = fake_param.data
|
|
|
|
to_update_chunk = chunk16.paired_chunk if self.module.master_weights else chunk16
|
|
fake_param.data = to_update_chunk.payload[begin:end]
|
|
|
|
def _update_fp16_params(self):
|
|
none_tensor = torch.empty([0])
|
|
for group in self.param_groups:
|
|
for fake_param in group["params"]:
|
|
assert fake_param.grad is None
|
|
fake_param.data = none_tensor.to(fake_param.device)
|
|
|
|
for chunk16 in self.chunk16_set:
|
|
chunk16.optim_update()
|
|
|
|
def _clear_global_norm(self) -> None:
|
|
for c16 in self.chunk16_set:
|
|
grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
|
|
grad_chunk.l2_norm = None
|
|
|
|
def _calc_global_norm(self) -> float:
    """Return the square root of the summed ``l2_norm`` values of all gradient chunks.

    Values of gathered chunks are added directly. Values of sharded chunks are
    accumulated per process group and all-reduced over that group before being
    added. Each chunk's ``l2_norm`` is cleared once it has been read.
    """
    total_sq: float = 0.0
    sharded_sq = dict()
    for chunk in self.chunk16_set:
        grads = chunk if self.module.reuse_fp16_chunk else chunk.grad_chunk
        assert grads.l2_norm is not None

        if grads.is_gathered:
            total_sq += grads.l2_norm
        else:
            # this chunk is sharded, use communication to collect total norm
            pg = grads.torch_pg
            sharded_sq.setdefault(pg, 0.0)
            sharded_sq[pg] += grads.l2_norm

        grads.l2_norm = None  # clear l2 norm

    # One scalar buffer is reused for every process group's reduction.
    buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
    for pg, partial_sq in sharded_sq.items():
        buffer.fill_(partial_sq)
        dist.all_reduce(buffer, group=pg)
        total_sq += buffer.item()

    return math.sqrt(total_sq)
|
|
|
|
def _get_combined_scale(self):
|
|
div_scale = self.mix_precision_mixin.get_grad_div_scale()
|
|
|
|
if self.clipping_flag:
|
|
total_norm = self._calc_global_norm()
|
|
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
|
|
if clip > 1:
|
|
div_scale = clip * div_scale
|
|
|
|
return -1 if div_scale == 1.0 else div_scale
|
|
|
|
def zero_grad(self, *args, **kwargs):
    """Run the mixed-precision pre-hook, then clear gradients on the inner optimizer.

    Positional and keyword arguments are accepted but not forwarded; the inner
    optimizer is always called with ``set_to_none=True``.
    """
    self.mix_precision_mixin.pre_zero_grad()
    outcome = self.optim.zero_grad(set_to_none=True)
    return outcome
|
|
|
|
def step(self, *args, **kwargs):
    """Run one optimizer step on the chunk-backed fake parameters.

    Extra positional and keyword arguments are forwarded to the inner optimizer's
    ``step`` together with ``div_scale``.

    Returns:
        Whatever the inner optimizer's ``step`` returns, or None when the
        mixed-precision helper reports that this step must be skipped.
    """
    if self.module.master_weights:
        self._maybe_move_fp32_params()
    # Point the fake parameters at their gradient and update slices.
    self._set_grad_ptr()

    if self.mix_precision_mixin.should_skip_step():
        if self.verbose:
            self._logger.info(f"Found overflow. Skip step")
        self._clear_global_norm()  # clear recorded norm
        self.zero_grad()  # reset all gradients
        # NOTE(review): this branch checks ``reuse_fp16_chunk`` while the normal path
        # below checks ``master_weights``, and ``accumulating_grads`` is not reset
        # here — confirm both are intended.
        if self.module.reuse_fp16_chunk:
            self._update_fp16_params()
        return

    # get combined scale. combined scale = loss scale * clipping norm
    # so that gradient = gradient / combined scale
    combined_scale = self._get_combined_scale()

    ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
    self._register_states()
    self.zero_grad()
    if self.module.master_weights:
        self._update_fp16_params()
    self.module.accumulating_grads = False
    return ret
|
|
|
|
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
    """Unsupported: always raises.

    Clipping for this optimizer is configured through the ``max_norm`` constructor
    argument instead.

    Raises:
        NotImplementedError: Always.
    """
    # Carry a message so the failure is self-explanatory, like clip_grad_by_value.
    raise NotImplementedError("Gemini does not support clip_grad_norm")
|
|
|
|
def backward(self, loss: torch.Tensor):
    """Pass ``loss`` through the mixed-precision pre-backward hook and backpropagate it via the module."""
    prepared_loss = self.mix_precision_mixin.pre_backward(loss)
    self.module.backward(prepared_loss)
|
|
|
|
def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
    """Backpropagate ``tensor`` with an externally supplied gradient.

    Used on every pipeline stage except the last one. The incoming gradient was
    already scaled by the previous rank, so it is not scaled again here; unscaling
    happens at optimization time.
    """
    prepared_grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
    self.module.backward_by_grad(tensor, prepared_grad)
|
|
|
|
def _maybe_move_fp32_params(self):
|
|
if self._should_move_fp32_params_h2d:
|
|
self._should_move_fp32_params_h2d = False
|
|
available_cuda_margin_mem = self.gemini_manager.cuda_margin_mem * self.gpu_margin_mem_ratio
|
|
fp32_params_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
|
|
fp32_params_used_cuda_margin_mem = 0
|
|
|
|
for group in self.param_groups:
|
|
for fake_param in group["params"]:
|
|
chunk16 = self.param_to_chunk16[fake_param]
|
|
chunk32 = chunk16.paired_chunk
|
|
|
|
if chunk32.device_type == "cuda":
|
|
continue
|
|
|
|
if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
|
|
self.chunk_manager.move_chunk(chunk32, get_current_device())
|
|
# stores grad now
|
|
self.chunk_manager.move_chunk(chunk16, get_current_device())
|
|
self.module.set_chunk_grad_device(chunk16, get_current_device())
|
|
fp32_params_used_cuda_margin_mem += chunk32.payload_mem
|
|
|
|
for group in self.param_groups:
|
|
for fake_param in group["params"]:
|
|
chunk16 = self.param_to_chunk16[fake_param]
|
|
chunk32 = chunk16.paired_chunk
|
|
if chunk32.device_type == "cuda":
|
|
state = self.optim.state[fake_param]
|
|
for k, v in state.items():
|
|
if isinstance(v, torch.Tensor):
|
|
state[k] = v.to(get_current_device())
|
|
|
|
def _register_states_(self):
|
|
for group in self.optim.param_groups:
|
|
for p in group["params"]:
|
|
state = self.optim.state[p]
|
|
for val in state.values():
|
|
if isinstance(val, torch.Tensor):
|
|
self.chunk_manager.add_extern_static_tensor(val)
|
|
|
|
def __init__optimizer(self):
    """Swap the inner optimizer's parameters for fake parameters backed by chunk ranges.

    Every real parameter gets a consecutive integer id (recorded in
    ``id_to_real_params``). Parameters that are not ignored by DDP and whose range in
    the local chunk shard is non-empty get a fake parameter, recorded in
    ``param_to_chunk16``, ``param_to_range`` and ``id_to_fake_params``. Each group's
    hyperparameters plus its list of ids are appended to ``param_groups_backup``.
    """

    def get_range_pair(local_chunk: Chunk, local_param: Parameter):
        # Range of ``local_param`` inside the locally held part of ``local_chunk``.
        info = local_chunk.tensors_info[local_param]
        if local_chunk.keep_gathered:
            return info.offset, info.end
        shard_begin = local_chunk.shard_begin
        clipped_begin = max(0, info.offset - shard_begin)
        clipped_end = min(local_chunk.shard_size, info.end - shard_begin)
        return clipped_begin, clipped_end

    next_id = 0
    for group in self.optim.param_groups:
        owned_fake_params = list()
        ids_in_group = []
        hyperparams = {key: value for key, value in group.items() if key != "params"}
        for real_param in group["params"]:
            # Record the mapping of id to current param.
            current_id = next_id
            next_id += 1
            self.id_to_real_params[current_id] = real_param
            ids_in_group.append(current_id)

            # If current param is controlled by current process, add it to fake_param.
            if is_ddp_ignored(real_param):
                continue
            chunk16 = self.chunk_manager.get_chunk(real_param)
            local_range = get_range_pair(chunk16, real_param)
            if local_range[0] >= local_range[1]:
                # Nothing of this parameter lives in the local shard.
                continue
            placeholder = torch.empty([0], device=self.module.grads_device[real_param])
            fake_param = torch.nn.Parameter(placeholder)
            self.param_to_chunk16[fake_param] = chunk16
            self.param_to_range[fake_param] = local_range
            self.id_to_fake_params[current_id] = fake_param
            owned_fake_params.append(fake_param)

        # Update self.optim.param_groups as well as backup group.
        group["params"] = owned_fake_params
        hyperparams["params"] = ids_in_group
        self.param_groups_backup.append(hyperparams)
|
|
|
|
def get_offsets(self, param_id: int) -> tuple:
    """Locate the locally owned shard of a parameter.

    Args:
        param_id(int): The id of parameter.

    Returns:
        chunk_offset(int): Offset of parameter inside the chunk.
        shard_offset(int): Offset of its optimizer state shard
                            relative to the whole optimizer state.
        shard_size(int): Length of parameter shard owned by current process.

        All three are -1 when the id has no fake parameter on this process.
    """
    try:
        fake_param = self.id_to_fake_params[param_id]
    except KeyError:
        return -1, -1, -1

    chunk = self.param_to_chunk16[fake_param]
    param_info = chunk.tensors_info[self.id_to_real_params[param_id]]

    chunk_offset, range_end = self.param_to_range[fake_param]
    shard_size = range_end - chunk_offset
    # A gathered chunk holds the whole parameter, so its state starts at offset 0.
    shard_offset = 0 if chunk.keep_gathered else chunk_offset + chunk.shard_begin - param_info.offset
    assert chunk_offset >= 0 and shard_offset >= 0
    return chunk_offset, shard_offset, shard_size
|
|
|
|
def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
    """Gather the optimizer states of one parameter.

    Args:
        param_id (int): id of the parameter whose state is to be gathered at master rank.
        only_rank_0(bool): if True, states will be collected only on master rank, otherwise collected on every rank.

    Returns:
        collected_states(dict): the gathered optimizer state of parameter with given id
                                if this method is called by master rank, otherwise an empty dict.

    This method can work only when called by all processes simultaneously.
    """

    # Get param & chunk & process group.
    param = self.id_to_real_params[param_id]
    fake_param = self.id_to_fake_params.get(param_id, None)
    chunk = self.chunk_manager.get_chunk(param)
    dp_group = chunk.torch_pg
    rank = dist.get_rank(dp_group)
    master_rank = 0
    collected_states = {}

    # Fetch names of states through all_gather.
    local_state_names = None
    if fake_param is not None:
        local_state_names = list(self.optim.state[fake_param].keys())
    gathered_state_names = [None for _ in range(dist.get_world_size(dp_group))]
    dist.barrier()
    dist.all_gather_object(gathered_state_names, local_state_names, dp_group)
    state_names = None
    for names in gathered_state_names:
        if names is not None:
            # Assume different devices share the same set of state names if they have.
            state_names = copy.deepcopy(names)
            break

    # Directly return if this parameter doesn't have optimizer states.
    # e.g. parameter frozen/layer dropped
    if state_names is None:
        return collected_states

    # Boolean variable is_collector indicates that whether the current rank
    # needs to gather the whole optimizer states.
    # Only master rank is collector when only_rank_0 is True.
    # Every rank is collector when only_rank_0 is False.
    is_collector = (rank == master_rank) or (not only_rank_0)

    # get tensor parallelism information
    is_dtensor = is_distributed_tensor(param)
    is_customized_distributed = is_customized_distributed_tensor(param)
    shard_spec = get_sharding_spec(param) if is_dtensor else None
    device_mesh = get_device_mesh(param) if is_dtensor else None
    global_shape = self.optimizer_params_info["id2shape"][param_id]

    # If the chunk is kept gathered,
    # the parameters are treated the same as that of those in strict DDP during training.
    # So states can be directly fetched from current device.
    if chunk.keep_gathered:
        assert param_id in self.id_to_fake_params
        if is_collector:
            states = self.optim.state[fake_param]
            for state_name in state_names:
                if state_name == "step":
                    # To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32.
                    collected_states[state_name] = torch.tensor(
                        states["step"], dtype=torch.float32, requires_grad=False
                    ).cpu()
                else:
                    state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
                    if is_dtensor:
                        state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
                        state_tensor = init_as_dtensor(
                            state_tensor,
                            device_mesh=device_mesh,
                            sharding_spec=shard_spec,
                            global_shape=global_shape,
                        )
                    elif is_customized_distributed:
                        state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
                        init_tensor_as_customization_distributed(
                            state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
                        )
                    state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()

                    collected_states[state_name] = state_tensor.reshape(global_shape)
        return collected_states

    # Check whether the param with given id is managed by current process.
    own_param = param_id in self.id_to_fake_params

    # Collector gets prepared for state collecting.
    if is_collector:
        for state_name in state_names:
            if state_name == "step":
                # To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32.
                collected_states[state_name] = torch.tensor(0.0, dtype=torch.float32, requires_grad=False).cpu()
            else:
                collected_states[state_name] = torch.zeros(
                    param.numel(), dtype=torch.float32, requires_grad=False
                ).cpu()

    # Materials for gathering, including compacted state tensors, and the offset of shard inside each state.
    compacted_states = self.pack_optimizer_states_to_tensor(param_id, state_names) if own_param else None
    _, shard_offset, shard_size = self.get_offsets(param_id)

    # Collectors gather state shards through all_gathering.
    gathered_state_shards = [None for _ in range(dist.get_world_size(dp_group))]

    dist.barrier()
    # The output list is sized by dp_group, so the gather must run on dp_group as well;
    # on the default group the sizes disagree whenever dp_group is a sub-group.
    dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size], group=dp_group)

    if is_collector:
        for state_shard in gathered_state_shards:
            compacted_states = state_shard[0]
            shard_offset = state_shard[1]
            shard_size = state_shard[2]
            if compacted_states is None:
                continue
            self.load_from_compacted_states(
                compacted_states, collected_states, state_names, shard_offset, shard_size
            )

    # Reshape tensors
    # NOTE(review): unlike the keep_gathered branch above, the tensor produced by the
    # gather below is only bound to the local name and never written back into
    # ``collected_states`` — confirm whether that is intended.
    if is_collector:
        for state_name, state_tensor in collected_states.items():
            if state_tensor.numel() == param.numel():
                collected_states[state_name] = torch.reshape(state_tensor, param.shape)
            if is_dtensor:
                state_tensor = state_tensor.to(param.device)
                state_tensor = init_as_dtensor(
                    state_tensor,
                    sharding_spec=shard_spec,
                    device_mesh=device_mesh,
                    global_shape=global_shape,
                )
            elif is_customized_distributed:
                state_tensor = state_tensor.to(param.device)
                init_tensor_as_customization_distributed(
                    state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
                )
            state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()

    return collected_states
|
|
|
|
def pack_optimizer_states_to_tensor(
    self,
    param_id: int,
    state_names: list,
    device: torch.device = torch.device("cuda"),
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    With param id given, pack its optimizer states into a compact tensor and return.

    The layout follows ``state_names`` in order: one slot for ``"step"`` and
    ``shard_size`` slots for every other state. This is the order in which
    ``load_from_compacted_states`` reads the tensor back.

    Args:
        param_id (int): id of the parameter whose local states are packed.
        state_names (list): names of the states, in packing order.
        device (torch.device): device of the returned tensor.
        dtype (torch.dtype): dtype of the returned tensor.

    Returns:
        The packed 1-D tensor, or None when this process holds no fake parameter
        for ``param_id``. Slots of names missing from the local state stay zero.
    """
    if param_id not in self.id_to_fake_params:
        return None

    fake_param = self.id_to_fake_params[param_id]
    range_begin, range_end = self.param_to_range[fake_param]
    states = self.optim.state[fake_param]
    shard_size = range_end - range_begin

    compacted_size = sum(1 if name == "step" else shard_size for name in state_names)
    compacted_states = torch.zeros(compacted_size, dtype=dtype, device=device, requires_grad=False)

    next_state_offset = 0
    # Walk ``state_names`` rather than ``states`` so the layout always matches the
    # size computed above and the order the unpacking side expects.
    for state_name in state_names:
        slot_width = 1 if state_name == "step" else shard_size
        if state_name in states:
            state_value = states[state_name]
            if state_name == "step":
                # State 'step' needs special operation.
                if isinstance(state_value, torch.Tensor):
                    # reshape(-1) makes this work for 0-dim as well as 1-element tensors.
                    compacted_states[next_state_offset] = state_value.reshape(-1)[0].item()
                else:
                    assert isinstance(state_value, int)
                    compacted_states[next_state_offset] = state_value
            else:
                assert state_value.numel() == shard_size
                compacted_states[next_state_offset : next_state_offset + shard_size].copy_(state_value)
        next_state_offset += slot_width

    return compacted_states
|
|
|
|
def load_from_compacted_states(
    self,
    compacted_states: torch.Tensor,
    collected_states: dict,
    state_names: list,
    shard_start: int,
    shard_size: int,
):
    """
    Unpack a compacted state tensor into ``collected_states``.

    ``compacted_states`` is read in ``state_names`` order: ``"step"`` takes one slot
    and replaces the data of ``collected_states["step"]``; every other state takes
    ``shard_size`` slots, copied into ``[shard_start, shard_start + shard_size)`` of
    the matching collected tensor.
    """
    cursor = 0
    shard_end = shard_start + shard_size

    for name in state_names:
        if name != "step":
            packed_segment = compacted_states[cursor : cursor + shard_size]
            collected_states[name][shard_start:shard_end].copy_(packed_segment)
            cursor += shard_size
            continue

        step_value = compacted_states[cursor].item()
        collected_states["step"].data = torch.tensor(
            step_value, dtype=torch.float32, requires_grad=False
        ).cpu()
        cursor += 1
|
|
|
|
def get_param_groups_for_saving(self) -> list:
    """
    Return a deep copy of ``param_groups_backup`` in Pytorch checkpoint format.

    Hyperparameters used by the pytorch Adam optimizer are added with fixed fallback
    values to every group that does not already define them.
    """
    # To be compatible with pytorch checkpointing,
    # store extra hyperparameters used by pytorch Adam optimizer.
    torch_adam_fallbacks = {
        "amsgrad": False,
        "maximize": False,
        "foreach": None,
        "capturable": False,
        "differentiable": False,
        "fused": False,
    }

    saved_groups = copy.deepcopy(self.param_groups_backup)
    for saved_group in saved_groups:
        for name, fallback in torch_adam_fallbacks.items():
            saved_group.setdefault(name, fallback)

    return saved_groups
|
|
|
|
def state_dict(self, only_rank_0: bool = True) -> dict:
    """
    Args:
        only_rank_0 (bool): a boolean value indicating whether the state_dict is collected
                            only on rank 0, default to True.

    Returns:
        The complete state of the optimizer as a :class:`dict`.
        It contains two entries:

        * state - a dict holding current optimization state. Its content
            differs between optimizer classes.
        * param_groups - a list containing all parameter groups where each
            parameter group is a dict.

    Warning: This method will gather and return the whole optimizer state_dict,
    so it should be called only when memory resources are abundant.
    """
    per_param_states = dict()
    full_state = {
        "param_groups": self.get_param_groups_for_saving(),
        "state": per_param_states,
    }

    # Collect optimizer states, one parameter at a time, with a barrier before each.
    for param_id in self.id_to_real_params.keys():
        dist.barrier()
        per_param_states[param_id] = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
    return full_state
|
|
|
|
def load_param_groups(self, saved_param_groups: list):
    """
    Rebuild ``self.optim.param_groups`` and ``self.param_groups_backup`` from saved groups.

    The backup receives a deep copy of the saved groups. In the inner optimizer, every
    saved parameter id is replaced by its fake parameter; ids without a fake parameter
    on this process are dropped.
    """
    self.param_groups_backup = copy.deepcopy(saved_param_groups)

    # discard the older param_groups
    self.optim.param_groups = []

    for saved_group in saved_param_groups:
        rebuilt_group = {key: value for key, value in saved_group.items() if key != "params"}
        rebuilt_group["params"] = [
            self.id_to_fake_params[param_id]
            for param_id in saved_group["params"]
            if param_id in self.id_to_fake_params
        ]
        self.optim.param_groups.append(rebuilt_group)
|
|
|
|
def load_single_param_states(self, param_id: int, saved_states: dict):
    """
    Load saved optimizer states into parameter with given id.

    Only the segment of each saved state that belongs to the local shard is copied;
    ``"step"`` is stored as a plain int.
    """

    def cast(param, state_range, value, key=None):
        """
        Make a copy of the needed segment of value and cast it to device of param.
        """
        assert isinstance(value, torch.Tensor)
        if key == "step":
            assert value.numel() == 1
            return int(value.item())

        state_start, state_end = state_range
        segment = torch.zeros(
            state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False
        )

        # For distributed parameters the saved value is first given its global shape
        # and distributed again before the local segment is cut out.
        if is_dtensor:
            reshaped = torch.reshape(value, global_shape)
            value = distribute_tensor(reshaped, sharding_spec=shard_spec, device_mesh=device_mesh)
        elif is_customized_distributed:
            reshaped = torch.reshape(value, global_shape)
            value = distribute_tensor_with_customization(reshaped, real_param.shard_fn, real_param.gather_fn)

        segment.copy_(value.flatten()[state_start:state_end])
        return segment

    assert param_id in self.id_to_fake_params
    fake_param = self.id_to_fake_params[param_id]
    _, state_offset, param_size = self.get_offsets(param_id)
    state_range = (state_offset, state_offset + param_size)

    # get tensor parallelism information
    real_param = self.id_to_real_params[param_id]
    is_dtensor = is_distributed_tensor(real_param)
    is_customized_distributed = is_customized_distributed_tensor(real_param)
    shard_spec = get_sharding_spec(real_param) if is_dtensor else None
    device_mesh = get_device_mesh(real_param) if is_dtensor else None
    global_shape = self.optimizer_params_info["id2shape"][param_id]

    # Copy states assigned to param (and cast tensors to appropriate types).
    updated_states = {
        name: cast(fake_param, state_range, saved_value, name) for name, saved_value in saved_states.items()
    }
    self.optim.state[fake_param].update(updated_states)
|
|
|
|
def load_param_states(self, param_states: dict):
    """Loads param states from a state_dict. The param_states can be complete or sharded.
    During loading, filter out the part of states not considered by current process.

    Args:
        param_states (dict): A mapping from param_id to its states.
    """
    for param_id, states in param_states.items():
        # Ids without a fake parameter on this process are skipped.
        if param_id not in self.id_to_fake_params:
            continue
        self.load_single_param_states(param_id, states)
|
|
|
|
def optimizer_loading_epilogue(self):
    """Finish loading a state_dict into the inner pytorch optimizer.

    Calls the version-appropriate private optimizer hook and makes sure the
    ``"differentiable"`` default exists.
    """
    # Epilogue when loading state_dict to pytorch optimizer.
    torch_is_2_or_newer = Version(torch.__version__) >= Version("2.0.0")
    if torch_is_2_or_newer:
        self.optim._patch_step_function()  # To support multiprocessing pickle/unpickle
    else:
        self.optim._hook_for_profile()  # To support multiprocessing pickle/unpickle.
    self.optim.defaults.setdefault("differentiable", False)
|
|
|
|
def load_state_dict(self, state_dict: dict):
    """Loads optimizer state from complete optimizer state_dict.
    During loading, filter out the part of states not considered by current process.

    Args:
        state_dict (dict): optimizer state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    for required_key in ("param_groups", "state"):
        assert required_key in state_dict

    # Groups are restored first, then per-parameter states, then the epilogue runs.
    saved_groups = state_dict["param_groups"]
    self.load_param_groups(saved_groups)
    saved_states = state_dict["state"]
    self.load_param_states(saved_states)
    self.optimizer_loading_epilogue()
|
|
|
|
def state_shard(
    self, prefix: str = "", max_shard_size: int = 1024, only_rank_0: bool = True
) -> Iterator[Tuple[OrderedDict, int]]:
    """Returns dictionaries containing shards of optimizer states one by one.
    The max size of each dictionary shard is specified by ``max_shard_size``.

    Args:
        prefix (str, optional): the prefix for states. Default to ''.
        max_shard_size (int, optional): max size of state dict shard (in MB). Defaults to 1024.
        only_rank_0 (bool, optional): a boolean value indicating whether the state_dict is collected
                                      only on rank 0, default to True.

    Yields:
        Iterator[OrderedDict]: A generator of state dict shard of optimizer states.
    """

    sharder = StateDictSharder(max_shard_size)
    for param_id in self.id_to_real_params.keys():
        dist.barrier()
        gathered = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)

        finished_block, finished_size = sharder.append_optim_state(param_id, gathered)
        if finished_block is None:
            continue
        yield finished_block, finished_size

    # Whatever is still held by the sharder is emitted last.
    yield sharder.current_block, sharder.current_block_size
|
|
|
|
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
    """Unsupported operation: always raises ``NotImplementedError``."""
    reason = "Gemini does not support clip_grad_by_value"
    raise NotImplementedError(reason)
|
|
|
|
def clip_grad_by_norm(
    self,
    max_norm: Union[float, int],
    norm_type: Union[float, int] = 2,
    error_if_nonfinite: bool = False,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Emit a warning only; no clipping is performed here and nothing is returned."""
    notice = "Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm"
    warnings.warn(notice)
|
|
|
|
|
|
class GeminiAdamOptimizer(GeminiOptimizer):
    """``GeminiOptimizer`` that builds its own ``HybridAdam`` from a model's parameters."""

    def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
        # The same keyword arguments go to both HybridAdam and GeminiOptimizer.
        inner_optimizer = HybridAdam(model.parameters(), **defaults)
        super().__init__(inner_optimizer, model, **defaults)
|