#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Any

import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.distributed import ReduceOp
from torch.optim import Optimizer
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import ColossalaiOptimizer

from ._fp16_optimizer import FP16Optimizer


class NaiveAMPOptimizer(ColossalaiOptimizer):
    """A wrapper class for an optimizer that casts all parameters to fp16.

    Args:
        optim (torch.optim.Optimizer): a regular optimizer such as Adam or SGD.
        grad_scaler (BaseGradScaler): the gradient scaler, chosen from
            ``constant_grad_scaler`` or ``dynamic_grad_scaler``.
        clip_grad_norm (float, optional): clip gradients with this global L2 norm. Default 0.
        verbose (bool, optional): if set to ``True``, will print debug info. Default False.

    Note:
        Clipping is ignored if ``clip_grad_norm`` equals 0.
    """

    def __init__(self, optim: Optimizer, *args, **kwargs):
        optim = FP16Optimizer(optim, *args, **kwargs)
        super().__init__(optim)

    def backward(self, loss: Tensor):
        # delegate to FP16Optimizer, which handles loss scaling
        self.optim.backward(loss)

    def step(self):
        return self.optim.step()

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        # intentionally a no-op: clipping is configured through the
        # ``clip_grad_norm`` argument and applied by the wrapped FP16Optimizer
        pass
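
# A minimal usage sketch (``model``, ``criterion``, ``inputs`` and ``targets``
# are hypothetical; the ``grad_scaler`` choices mirror the docstring above):
#
#   optim = NaiveAMPOptimizer(torch.optim.Adam(model.parameters(), lr=1e-3),
#                             grad_scaler=dynamic_grad_scaler,
#                             clip_grad_norm=1.0)
#   loss = criterion(model(inputs), targets)
#   optim.backward(loss)  # scaled-loss backward through FP16Optimizer
#   optim.step()          # unscale, clip (if clip_grad_norm > 0), update

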
class NaiveAMPModel(nn.Module):
    r"""A wrapper class that casts a model to fp16 and automatically casts
    its inputs and outputs.

    Args:
        model (torch.nn.Module): the torch.nn.Module to be wrapped.
        output_to_fp32 (bool, optional): whether to cast the output of this module to fp32. (Default: True)
        parallel_mode (:class:`colossalai.context.ParallelMode`): the parallel group mode used in this module.
            (Default: ``ParallelMode.DATA``)
        sync_buffer (bool, optional): whether to synchronize buffers. (Default: True)

    Note:
        ``parallel_mode`` must be a member of ``ParallelMode``. More details about ``ParallelMode`` can be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
    """

    def __init__(self,
                 model: nn.Module,
                 output_to_fp32: bool = True,
                 parallel_mode: ParallelMode = ParallelMode.DATA,
                 sync_buffer: bool = True):
        super().__init__()
        self.model = model.half()
        self._output_to_fp32 = output_to_fp32
        self._sync_buf = sync_buffer

        if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
            self._process_group = gpc.get_group(parallel_mode)
            self._world_size = gpc.get_world_size(parallel_mode)
        else:
            # with a single rank there is nothing to synchronize
            self._process_group = None
            self._world_size = 1
            self._sync_buf = False
        self._first_eval_run = False
    @property
    def sync_buffer(self):
        return self._sync_buf

    @sync_buffer.setter
    def sync_buffer(self, state: bool):
        self._sync_buf = state

    def _convert_to_fp16(self, input_: Any):
        if isinstance(input_, Tensor) and input_.dtype == torch.float32:
            input_ = input_.half()
        return input_

    def _convert_to_fp32(self, input_: Any):
        if isinstance(input_, Tensor) and input_.dtype == torch.float16:
            input_ = input_.float()
        return input_
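
    # Behavior sketch for the two helpers above (example values are
    # illustrative only; dtypes follow standard torch semantics):
    #   _convert_to_fp16(torch.ones(2))                      -> fp16 tensor
    #   _convert_to_fp16(torch.ones(2, dtype=torch.int64))   -> unchanged (not fp32)
    #   _convert_to_fp16("not a tensor")                     -> unchanged
    #   _convert_to_fp32(torch.ones(2, dtype=torch.float16)) -> fp32 tensor
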
    def _reduce_module_buffer(self):
        """
        All-reduce the buffers (e.g. running stats of batch normalization) across
        data parallel ranks so that all ranks produce consistent results
        when given the same input.
        """
        # find valid buffers
        buf_list = [buf for buf in self.model.buffers() if buf is not None]

        # reduce buffers across data parallel ranks
        if buf_list:
            coalesced_buf = _flatten_dense_tensors(buf_list)
            coalesced_buf.div_(self._world_size)
            dist.all_reduce(coalesced_buf, op=ReduceOp.SUM, group=self._process_group)
            unflattened_buf_list = _unflatten_dense_tensors(coalesced_buf, buf_list)
            for old, new in zip(buf_list, unflattened_buf_list):
                old.copy_(new)
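
    # Worked example of the coalesced pattern above (values are made up):
    # with world_size == 2, a BatchNorm running_mean of [1., 3.] on rank 0 and
    # [3., 5.] on rank 1 is flattened, divided by 2, and summed by all_reduce,
    # leaving both ranks with the averaged buffer [2., 4.]. Flattening lets a
    # module with many small buffers be synchronized in a single collective.
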
    def eval(self):
        self.model.eval()

        # we only sync buffers in the first eval iteration
        # so that subsequent eval iterations can run without communication
        self._first_eval_run = True
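
    # Sync timeline sketch (assuming a multi-rank data-parallel setup with
    # sync_buffer enabled):
    #   model.train(); model(x)  -> buffers all-reduced before every forward
    #   model.eval();  model(x)  -> buffers all-reduced once, on the first forward
    #   model(x)                 -> later eval forwards need no communication
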
    def forward(self, *args, **kwargs):
        # reducing buffers after the forward pass would raise an error,
        # because variables needed for gradient computation cannot be modified
        # in place once forward has run, so we sync buffers before forward
        if (self.training or self._first_eval_run) and self._sync_buf:
            with torch.no_grad():
                self._reduce_module_buffer()

        if self._first_eval_run:
            self._first_eval_run = False

        if args:
            args = [self._convert_to_fp16(arg) for arg in args]
        if kwargs:
            for k, v in kwargs.items():
                kwargs[k] = self._convert_to_fp16(v)

        out = self.model(*args, **kwargs)

        if self._output_to_fp32:
            if isinstance(out, Tensor):
                out = self._convert_to_fp32(out)
            elif isinstance(out, (tuple, list)):
                # note: a tuple output is returned as a list here
                out = [self._convert_to_fp32(val) for val in out]
        return out
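

# End-to-end sketch combining the two wrappers (``Net`` and the data are
# hypothetical; colossalai/gpc initialization is assumed to have happened):
#
#   model = NaiveAMPModel(Net().cuda())      # weights cast to fp16
#   optim = NaiveAMPOptimizer(torch.optim.SGD(model.parameters(), lr=0.1),
#                             grad_scaler=constant_grad_scaler)
#   out = model(torch.randn(4, 8).cuda())    # fp32 input cast to fp16 inside,
#                                            # output cast back to fp32
#   loss = out.sum()
#   optim.backward(loss)
#   optim.step()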