mirror of https://github.com/hpcaitech/ColossalAI
86 lines
3.7 KiB
Python
86 lines
3.7 KiB
Python
|
from typing import Optional
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
from torch.nn.parameter import Parameter
|
||
|
|
||
|
|
||
|
class ShardedGradient:
|
||
|
def __init__(self,
|
||
|
param: Parameter,
|
||
|
sharded_module: nn.Module,
|
||
|
offload_config: Optional[dict] = None
|
||
|
) -> None:
|
||
|
assert hasattr(
|
||
|
param, 'ca_attr') and param.ca_attr.is_sharded, 'ShardedGradient can only be initialized with sharded parameter'
|
||
|
|
||
|
self.param = param
|
||
|
self.sharded_module = sharded_module
|
||
|
self.offload_config = offload_config
|
||
|
|
||
|
self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False
|
||
|
|
||
|
# _gpu_grad is either sharded or not
|
||
|
# all saved grads are fp32
|
||
|
self._gpu_grad: Optional[torch.Tensor] = None
|
||
|
self._cpu_grad: Optional[torch.Tensor] = None
|
||
|
|
||
|
if self._cpu_offload:
|
||
|
# this buffer will be held and reused every iteration
|
||
|
self._cpu_grad = torch.zeros(param.ca_attr.payload('cpu'), dtype=torch.float).pin_memory()
|
||
|
|
||
|
@torch.no_grad()
|
||
|
def setup(self) -> None:
|
||
|
"""This function will be called pre-backward. Save the local accumulated gradient to _gpu_grad.
|
||
|
When no_sync() is enable (_require_backward_grad_sync=False), the grad is accumulated locally in param.grad
|
||
|
|
||
|
:raises AssertionError: Raise if grad shape is wrong
|
||
|
"""
|
||
|
if self.sharded_module._require_backward_grad_sync and self.param.grad is not None:
|
||
|
if self.param.grad.device != self.param.data.device:
|
||
|
# TODO: offload?
|
||
|
raise RuntimeError(
|
||
|
'grad and param are on different device, grad {self.param.grad.device} vs. param {self.param.data.device}')
|
||
|
else:
|
||
|
self._gpu_grad = self.param.grad.data
|
||
|
self.param.grad = None
|
||
|
|
||
|
def reduce_scatter_callback(self, reduced_grad: torch.Tensor) -> None:
|
||
|
"""This function will be called in post-backward hook, so we cannot modify param.grad directly
|
||
|
|
||
|
:param reduced_grad: the reduced grad
|
||
|
:type reduced_grad: torch.Tensor
|
||
|
"""
|
||
|
# Make sure we store fp32 grad
|
||
|
if torch.is_floating_point(reduced_grad) and reduced_grad.dtype != torch.float:
|
||
|
reduced_grad.data = reduced_grad.data.to(torch.float)
|
||
|
|
||
|
if self._gpu_grad is None:
|
||
|
self._gpu_grad = reduced_grad.data
|
||
|
else:
|
||
|
self._gpu_grad += reduced_grad.data
|
||
|
|
||
|
# Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
|
||
|
# backwards pass completes, we will set `.grad` to the CPU copy.
|
||
|
if self._cpu_offload:
|
||
|
self._cpu_grad.copy_(self._gpu_grad.data, non_blocking=True)
|
||
|
# Don't let this memory get reused until after the transfer.
|
||
|
self._gpu_grad.data.record_stream(torch.cuda.current_stream())
|
||
|
|
||
|
@torch.no_grad()
|
||
|
def write_back(self) -> None:
|
||
|
"""This function will be called in final backward hook
|
||
|
"""
|
||
|
if self._cpu_grad is not None:
|
||
|
assert self.param.device == torch.device(
|
||
|
'cpu'), f'Incorrect param device, expected CPU, got {self.param.device}'
|
||
|
self.param.grad.data = self._cpu_grad
|
||
|
elif self._gpu_grad is not None:
|
||
|
assert self.param.device == self._gpu_grad.device, f'Incorrect _gpu_grad device, param on {self.param.device} but _gpu_grad on {self._gpu_grad.device}'
|
||
|
self.param.grad.data = self._gpu_grad
|
||
|
else:
|
||
|
raise RuntimeError('No grad to write back')
|
||
|
# If using CPU offload, _cpu_grad will store the CPU tensor of _gpu_grad
|
||
|
# They should be released here
|
||
|
self._gpu_grad = None
|