InternLM/internlm/core/naive_amp.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# adapted from https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/amp

from typing import Any

import torch
import torch.distributed as dist
from torch import Tensor, nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ReduceOp

from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc


class NaiveAMPModel(nn.Module):
"""
This is a wrapper class for a model that automatically casts the model, its inputs, and outputs into fp16.
It also provides options to cast the output back to fp32 and to synchronize buffers.
Args:
model (torch.nn.Module): The model to be wrapped and cast into fp16.
output_to_fp32 (bool, optional): If True, the output of this module is cast into fp32. Defaults to True.
parallel_mode (:class:`internlm.core.context.ParallelMode`): The parallel group mode used in this module.
Defaults to ``ParallelMode.DATA``.
sync_buffer (bool, optional): If True, the buffers are synchronized. Defaults to True.
"""
def __init__(
self,
model: nn.Module,
output_to_fp32: bool = True,
parallel_mode: ParallelMode = ParallelMode.DATA,
sync_buffer: bool = True,
dtype=torch.float16,
):
super().__init__()
self.model = model.to(dtype)
self._output_to_fp32 = output_to_fp32
self._sync_buf = sync_buffer
self.dtype = dtype
if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
self._process_group = gpc.get_group(parallel_mode)
self._world_size = gpc.get_world_size(parallel_mode)
else:
self._process_group = None
self._world_size = 1
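            # With a single rank (or an uninitialized parallel context) there is
            # nothing to synchronize, so buffer synchronization is disabled.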
self._sync_buf = False
        self._first_eval_run = False

    @property
def sync_buffer(self):
"""Returns the current state of the buffer synchronization."""
        return self._sync_buf

    @sync_buffer.setter
def sync_buffer(self, state: bool):
"""Sets the state of the buffer synchronization."""
        self._sync_buf = state

    def _convert_to_fp16(self, input_: Any):
        """Converts the input to the low-precision dtype (fp16 by default) if it is a float32 Tensor."""
if isinstance(input_, Tensor) and input_.dtype == torch.float32:
input_ = input_.to(self.dtype)
        return input_

    def _convert_to_fp32(self, input_: Any):
"""Converts the input to fp32 if it is a Tensor of dtype float16."""
if isinstance(input_, Tensor) and input_.dtype == torch.float16:
input_ = input_.float()
        return input_

    def _reduce_module_buffer(self):
"""
All-reduces the buffers (e.g., running stats of batch normalization) across
data parallel ranks so that all the ranks will produce consistent results
when given the same input.
"""
buf_list = []
# find valid buffers
for buf in self.model.buffers():
if buf is not None:
buf_list.append(buf)
# reduce buffers across data parallel ranks
if buf_list:
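            # Flatten all buffers into a single contiguous tensor so one all-reduce
            # covers every buffer; dividing by the world size before the SUM
            # reduction makes the result an average across ranks.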
coalesced_buf = _flatten_dense_tensors(buf_list)
coalesced_buf.div_(self._world_size)
dist.all_reduce(coalesced_buf, op=ReduceOp.SUM, group=self._process_group)
unflattened_buf_list = _unflatten_dense_tensors(coalesced_buf, buf_list)
for old, new in zip(buf_list, unflattened_buf_list):
                old.copy_(new)

    def eval(self):
"""Sets the model to evaluation mode. Buffers are only synchronized in the first eval iteration."""
self.model.eval()
        self._first_eval_run = True

    def forward(self, *args, **kwargs):
"""
Performs a forward pass on the model. Buffers are synchronized before the forward pass.
        The inputs are converted to the low-precision dtype and the outputs are optionally converted back to fp32.
"""
if (self.training or self._first_eval_run) and self._sync_buf:
with torch.no_grad():
self._reduce_module_buffer()
if self._first_eval_run:
self._first_eval_run = False
if args:
args = [self._convert_to_fp16(arg) for arg in args]
if kwargs:
for k, v in kwargs.items():
kwargs[k] = self._convert_to_fp16(v)
out = self.model(*args, **kwargs)
if self._output_to_fp32:
if isinstance(out, Tensor):
out = self._convert_to_fp32(out)
elif isinstance(out, (tuple, list)):
out = [self._convert_to_fp32(val) for val in out]
elif isinstance(out, dict):
out = {key: self._convert_to_fp32(val) for key, val in out.items()}
return out
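

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; not part of the original module). It
# assumes a single process with no parallel context initialized, so buffer
# synchronization is disabled and only the dtype-casting behaviour is shown.
# fp16 matmuls are generally unsupported on CPU, hence the CUDA guard; the
# model and variable names below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    if torch.cuda.is_available():
        # A toy fp32 model; NaiveAMPModel casts its parameters to fp16.
        toy_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2)).cuda()
        amp_model = NaiveAMPModel(toy_model, output_to_fp32=True, dtype=torch.float16)

        # The fp32 input is cast to fp16 inside forward(); the output is cast back to fp32.
        x = torch.randn(4, 8, device="cuda")
        out = amp_model(x)
        print(out.dtype)  # torch.float32, because output_to_fp32=True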