2021-10-28 16:21:23 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
|
|
|
|
import torch
|
2022-03-15 02:05:38 +00:00
|
|
|
import torch.distributed as dist
|
2022-11-08 07:07:02 +00:00
|
|
|
from torch.distributed import ProcessGroup
|
2021-10-28 16:21:23 +00:00
|
|
|
from torch.optim import Optimizer
|
2022-11-08 07:07:02 +00:00
|
|
|
|
2023-09-18 08:31:06 +00:00
|
|
|
from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler
|
2024-01-25 09:01:48 +00:00
|
|
|
from colossalai.kernel.kernel_loader import FusedOptimizerLoader
|
2023-09-18 08:31:06 +00:00
|
|
|
from colossalai.legacy.context import ParallelMode
|
|
|
|
from colossalai.legacy.core import global_context as gpc
|
|
|
|
from colossalai.legacy.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
from colossalai.logging import get_dist_logger
|
2023-09-18 08:31:06 +00:00
|
|
|
from colossalai.utils import multi_tensor_applier
|
2022-11-08 07:07:02 +00:00
|
|
|
|
2022-03-15 02:05:38 +00:00
|
|
|
from ._utils import has_inf_or_nan, zero_gard_by_list
|
2021-10-28 16:21:23 +00:00
|
|
|
|
2023-01-06 12:50:26 +00:00
|
|
|
# The pre-built C extension is optional. When it cannot be imported the name is
# left as None and load_fused_optim() is expected to fill it in on first use.
try:
    from colossalai._C import fused_optim
except ImportError:
    # Catch only import failures so that KeyboardInterrupt / SystemExit and
    # unrelated errors are no longer silently swallowed.
    fused_optim = None
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
__all__ = ["FP16Optimizer"]
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
|
2023-01-06 12:50:26 +00:00
|
|
|
def load_fused_optim():
    """Populate the module-level ``fused_optim`` on first use.

    Does nothing when ``fused_optim`` is already set; otherwise stores the
    result of ``FusedOptimizerLoader().load()`` in it.
    """
    global fused_optim

    if fused_optim is not None:
        return
    fused_optim = FusedOptimizerLoader().load()
|
2023-01-06 12:50:26 +00:00
|
|
|
|
|
|
|
|
2021-10-28 16:21:23 +00:00
|
|
|
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
|
2022-03-15 02:05:38 +00:00
|
|
|
"""
|
|
|
|
adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM)
|
|
|
|
|
|
|
|
Use multi-tensor-applier to copy values from one list to another.
|
2021-10-28 16:21:23 +00:00
|
|
|
We don't have a blfoat16 implementation so for now if the overflow_buf
|
|
|
|
is not provided, we default back to simple loop copy to be compatible
|
2022-03-15 02:05:38 +00:00
|
|
|
with bfloat16.
|
|
|
|
"""
|
2021-10-28 16:21:23 +00:00
|
|
|
if overflow_buf:
|
|
|
|
overflow_buf.fill_(0)
|
|
|
|
# Scaling with factor `1.0` is equivalent to copy.
|
2023-01-06 12:50:26 +00:00
|
|
|
global fused_optim
|
|
|
|
load_fused_optim()
|
2022-12-23 06:14:21 +00:00
|
|
|
multi_tensor_applier(fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0)
|
2021-10-28 16:21:23 +00:00
|
|
|
else:
|
|
|
|
for this_, that_ in zip(this, that):
|
|
|
|
that_.copy_(this_)
|
|
|
|
|
|
|
|
|
|
|
|
class FP16Optimizer(Optimizer):
    """Mixed-precision wrapper that steps an inner optimizer on fp32 master weights.

    For every trainable ``torch.cuda.HalfTensor`` parameter found in the wrapped
    optimizer's param groups, an fp32 copy is created and swapped into the
    optimizer in its place. Trainable ``torch.cuda.FloatTensor`` parameters are
    left where they are. Any other trainable tensor type is rejected with
    ``TypeError`` (this includes bfloat16 tensors).

    Args:
        optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
        grad_scaler (BaseGradScaler): grad scaler used to scale the loss and
            unscale the gradients, chosen in ``constant_grad_scaler`` or
            ``dynamic_grad_scaler``.
        verbose (bool, optional): if set to `True`, will log the configuration
            on rank 0. Default False.
        clip_grad_norm (float, optional): clip gradients with this global norm.
            Default 0. Clipping is skipped unless ``clip_grad_norm > 0``.
        dp_process_group (ProcessGroup, optional): group over which the overflow
            flag is max-reduced. When ``None`` it is looked up in the global
            context for ``ParallelMode.DATA``.
        mp_process_group (ProcessGroup, optional): second group over which the
            overflow flag is max-reduced. When ``None`` it is looked up in the
            global context for ``ParallelMode.MODEL``.

    Raises:
        TypeError: if a trainable parameter is neither ``torch.cuda.HalfTensor``
            nor ``torch.cuda.FloatTensor``.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        grad_scaler: BaseGradScaler,
        verbose: bool = False,
        clip_grad_norm=0,
        dp_process_group: ProcessGroup = None,
        mp_process_group: ProcessGroup = None,
    ):
        # NOTE: Optimizer.__init__ is intentionally not called; ``defaults``,
        # ``state`` and ``param_groups`` are forwarded to the wrapped optimizer
        # through the properties defined on this class.
        self._optimizer = optimizer
        self._defaults = optimizer.defaults

        # fp16-related params
        assert isinstance(grad_scaler, BaseGradScaler)
        self._grad_scaler = grad_scaler
        self._found_overflow = torch.cuda.FloatTensor([0.0])
        self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # misc params
        self._clip_grad_max_norm = clip_grad_norm

        # get process group
        def _get_process_group(parallel_mode):
            if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode):
                return gpc.get_group(parallel_mode)
            return None

        if dp_process_group is None:
            dp_process_group = _get_process_group(ParallelMode.DATA)
        if mp_process_group is None:
            mp_process_group = _get_process_group(ParallelMode.MODEL)

        self._dp_process_group = dp_process_group
        self._mp_process_group = mp_process_group

        # we maintain three groups of parameters
        # so that the model can have a mixture
        # of fp16 and fp32 params
        # fp16_param_groups: the fp16 params of the model
        # fp32_master_param_groups: the fp32 params cast from the fp16 param of the model
        # fp32_param_groups: the fp32 params of the model
        # NOTE:
        # 1. fp16_param_groups and fp32_master_param_groups have one-to-one correspondence
        # 2. fp32_param_groups and fp16_param_groups are exclusive of each other
        self._fp16_param_groups = []
        self._fp32_master_param_groups = []
        self._fp32_param_groups = []

        # For all the groups in the original optimizer:
        for param_group in self._optimizer.param_groups:
            fp16_params = []
            fp32_master_params = []
            fp32_params = []

            # For all the parameters in this group:
            for i, param in enumerate(param_group["params"]):
                # Frozen parameters are left untouched in the optimizer.
                if not param.requires_grad:
                    continue

                param_type = param.type()
                if param_type in ["torch.cuda.HalfTensor"]:
                    # float16 params:
                    fp16_params.append(param)

                    # Create a fp32 copy
                    fp32_param = param.detach().clone().float()
                    # Copy tensor model parallel attributes.
                    copy_tensor_parallel_attributes(param, fp32_param)

                    # Replace the optimizer params with the new fp32 copy.
                    param_group["params"][i] = fp32_param
                    fp32_master_params.append(fp32_param)

                    # Reset existing state dict key to the new main param.
                    if param in self._optimizer.state:
                        self._optimizer.state[fp32_param] = self._optimizer.state.pop(param)
                elif param_type == "torch.cuda.FloatTensor":
                    # fp32 params.
                    fp32_params.append(param)
                else:
                    raise TypeError(
                        "Expected parameter of type torch.cuda.FloatTensor "
                        f"or torch.cuda.HalfTensor, but got {param.type()}"
                    )

            self._fp16_param_groups.append(fp16_params)
            self._fp32_master_param_groups.append(fp32_master_params)
            self._fp32_param_groups.append(fp32_params)

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
        self._optimizer.load_state_dict(self._optimizer.state_dict())

        # log config
        self._logger = get_dist_logger()
        if verbose:
            self._logger.info(
                f"\n========= FP16 Optimizer Config =========\n"
                f"Optimizer: {optimizer.__class__.__name__}\n"
                f"clip_grad_norm = {clip_grad_norm}\n"
                # Trailing newline added so the footer rule starts on its own line.
                f"grad_scaler = {self._grad_scaler.__class__.__name__}\n"
                f"==========================================",
                ranks=[0],
            )

    @property
    def max_norm(self):
        """Returns the maximum norm of gradient clipping."""
        return self._clip_grad_max_norm

    @property
    def grad_scaler(self):
        """Returns the gradient scaler.

        Returns:
            :class:`BaseGradScaler`: gradient scaler.
        """
        return self._grad_scaler

    @property
    def loss_scale(self):
        """Returns the loss scale.

        Returns:
            The ``scale`` attribute of the gradient scaler.
        """
        return self._grad_scaler.scale

    @property
    def optimizer(self):
        """Returns the optimizer.

        Returns:
            :class:`torch.optim.Optimizer`: the optimizer object wrapped.
        """
        return self._optimizer

    @property
    def defaults(self):
        """Returns the default arguments of optimizer.

        Returns:
            dict: optimizer arguments saved in defaults of the optimizer wrapped.
        """
        return self._defaults

    def _check_overflow(self):
        """Return True if any gradient held by the wrapped optimizer has inf/nan.

        The local result is written to ``self._found_overflow`` and max-reduced
        over the data-parallel and model-parallel groups when they are set.
        """
        # clear previous overflow record
        self._found_overflow.fill_(0.0)

        # check for overflow; ``any`` stops at the first offending gradient
        # across *all* groups instead of only leaving the inner loop.
        grads = (
            p.grad for group in self._optimizer.param_groups for p in group["params"] if p.grad is not None
        )
        if any(has_inf_or_nan(grad) for grad in grads):
            self._found_overflow.fill_(1.0)

        # all-reduce across dp group
        if self._dp_process_group is not None:
            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_process_group)

        # all-reduce over model parallel group
        if self._mp_process_group is not None:
            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_process_group)

        return self._found_overflow.item() > 0

    def zero_grad(self, set_to_none=True):
        """Set gradient to zero.

        Only the parameters currently held by the wrapped optimizer are reset.

        Args:
            set_to_none (bool): Whether set the gradient to None.
        """
        # set_to_none = True can save some memory space
        for param_group in self._optimizer.param_groups:
            zero_gard_by_list(param_group["params"], set_to_none=set_to_none)

    def _get_fp32_param_groups_to_update(self):
        """Return the fp32 master groups followed by the native fp32 groups."""
        return self._fp32_master_param_groups + self._fp32_param_groups

    def _unscale_grads(self):
        """Divide every existing fp32 gradient in place by the current loss scale."""
        scale = self.loss_scale
        for group in self._get_fp32_param_groups_to_update():
            for p in group:
                if p.grad is not None:
                    p.grad.data.div_(scale)

    def _assign_grad_to_fp32_master_param(self):
        """Move fp16 gradients onto their fp32 master parameters (cast to float)."""
        # This only needs to be done for the float16 group.
        for fp16_param_group, fp32_master_param_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
            for fp16_param, fp32_param in zip(fp16_param_group, fp32_master_param_group):
                if fp16_param.grad is not None:
                    fp32_param.grad = fp16_param.grad.float()
                    # clear unneeded grad on fp16 param
                    fp16_param.grad = None

    def _update_fp16_param_from_fp32_param(self):
        """Copy the fp32 master weights back into the fp16 model parameters."""
        fp16_param_data = []
        fp32_master_param_data = []
        for fp16_group, fp32_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
            for fp16_param, fp32_param in zip(fp16_group, fp32_group):
                fp16_param_data.append(fp16_param.data)
                fp32_master_param_data.append(fp32_param.data)
        _multi_tensor_copy_this_to_that(
            this=fp32_master_param_data, that=fp16_param_data, overflow_buf=self._dummy_overflow_buf
        )

    def step(self):
        """Update the model parameters.

        Returns:
            tuple: ``(True, grad_norm)`` after a successful update, where
            ``grad_norm`` is the value returned by :meth:`clip_grad_norm` (or
            ``None`` when clipping is disabled); ``(False, None)`` when an
            overflow was detected and the update was skipped.
        """
        # Copy gradients from model params to main params.
        self._assign_grad_to_fp32_master_param()
        self._unscale_grads()

        overflow = self._check_overflow()
        self._grad_scaler.update(overflow)
        if overflow:
            self.zero_grad()

        # Clip the main gradients.
        # NOTE(review): this still runs on an overflow step, after the grads
        # were cleared above. The original ordering is kept on purpose; confirm
        # whether clip_grad_norm_fp32 must be entered by every rank before
        # turning this into an early return.
        grad_norm = None
        if self._clip_grad_max_norm > 0.0:
            grad_norm = self.clip_grad_norm(self._clip_grad_max_norm)

        if overflow:
            return False, None

        # Step the optimizer.
        self._optimizer.step()

        # Update params from main params.
        self._update_fp16_param_from_fp32_param()

        # Successful update.
        return True, grad_norm

    def backward(self, loss):
        """Execute backward pass on the loss multiplied by the current scale.

        Args:
            loss (:class:`torch.Tensor`): the loss value.
        """
        scaled_loss = loss * self.grad_scaler.scale
        scaled_loss.backward()

    def state_dict(self):
        """Returns the states of the fp16 optimizer as a dict object.

        Keys: ``"optimizer"``, ``"grad_scaler"`` (only when a grad scaler is
        set) and ``"fp32_master_param_groups"``.
        """
        state_dict = {"optimizer": self._optimizer.state_dict()}
        if self.grad_scaler:
            state_dict["grad_scaler"] = self.grad_scaler.state_dict()
        state_dict["fp32_master_param_groups"] = self._fp32_master_param_groups
        return state_dict

    def load_state_dict(self, state_dict):
        """Load the states of the fp16 optimizer from a dict object.

        ``"optimizer"`` is required; ``"grad_scaler"`` and
        ``"fp32_master_param_groups"`` are restored only when present.

        Args:
            state_dict (dict): the states of the fp16 optimizer
        """
        # Optimizer.
        self._optimizer.load_state_dict(state_dict["optimizer"])

        # Grad scaler.
        if "grad_scaler" in state_dict:
            self.grad_scaler.load_state_dict(state_dict["grad_scaler"])

        # Copy data for the main params.
        if "fp32_master_param_groups" in state_dict:
            for current_group, ckpt_group in zip(
                self._fp32_master_param_groups, state_dict["fp32_master_param_groups"]
            ):
                for current_param, ckpt_param in zip(current_group, ckpt_group):
                    current_param.data.copy_(ckpt_param.data)

    def clip_grad_norm(self, clip_grad):
        """Clip gradients by norm.

        Args:
            clip_grad (float): the max norm for clipping

        Returns:
            Whatever ``clip_grad_norm_fp32`` returns for the parameters of the
            wrapped optimizer.
        """
        params = [param for param_group in self._optimizer.param_groups for param in param_group["params"]]
        return clip_grad_norm_fp32(params, clip_grad)

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self._optimizer.state

    def _set_state(self, value):
        self._optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self._optimizer.param_groups

    def _set_param_groups(self, value):
        self._optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)
|