ColossalAI/colossalai/zero/sharded_model/sharded_grad.py

86 lines
3.7 KiB
Python
Raw Normal View History

from typing import Optional
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
class ShardedGradient:
def __init__(self,
param: Parameter,
sharded_module: nn.Module,
offload_config: Optional[dict] = None
) -> None:
assert hasattr(
param, 'ca_attr') and param.ca_attr.is_sharded, 'ShardedGradient can only be initialized with sharded parameter'
self.param = param
self.sharded_module = sharded_module
self.offload_config = offload_config
self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False
# _gpu_grad is either sharded or not
# all saved grads are fp32
self._gpu_grad: Optional[torch.Tensor] = None
self._cpu_grad: Optional[torch.Tensor] = None
if self._cpu_offload:
# this buffer will be held and reused every iteration
self._cpu_grad = torch.zeros(param.ca_attr.payload('cpu'), dtype=torch.float).pin_memory()
@torch.no_grad()
def setup(self) -> None:
"""This function will be called pre-backward. Save the local accumulated gradient to _gpu_grad.
When no_sync() is enable (_require_backward_grad_sync=False), the grad is accumulated locally in param.grad
:raises AssertionError: Raise if grad shape is wrong
"""
if self.sharded_module._require_backward_grad_sync and self.param.grad is not None:
if self.param.grad.device != self.param.data.device:
# TODO: offload?
raise RuntimeError(
'grad and param are on different device, grad {self.param.grad.device} vs. param {self.param.data.device}')
else:
self._gpu_grad = self.param.grad.data
self.param.grad = None
def reduce_scatter_callback(self, reduced_grad: torch.Tensor) -> None:
"""This function will be called in post-backward hook, so we cannot modify param.grad directly
:param reduced_grad: the reduced grad
:type reduced_grad: torch.Tensor
"""
# Make sure we store fp32 grad
if torch.is_floating_point(reduced_grad) and reduced_grad.dtype != torch.float:
reduced_grad.data = reduced_grad.data.to(torch.float)
if self._gpu_grad is None:
self._gpu_grad = reduced_grad.data
else:
self._gpu_grad += reduced_grad.data
# Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
# backwards pass completes, we will set `.grad` to the CPU copy.
if self._cpu_offload:
self._cpu_grad.copy_(self._gpu_grad.data, non_blocking=True)
# Don't let this memory get reused until after the transfer.
self._gpu_grad.data.record_stream(torch.cuda.current_stream())
@torch.no_grad()
def write_back(self) -> None:
"""This function will be called in final backward hook
"""
if self._cpu_grad is not None:
assert self.param.device == torch.device(
'cpu'), f'Incorrect param device, expected CPU, got {self.param.device}'
self.param.grad.data = self._cpu_grad
elif self._gpu_grad is not None:
assert self.param.device == self._gpu_grad.device, f'Incorrect _gpu_grad device, param on {self.param.device} but _gpu_grad on {self._gpu_grad.device}'
self.param.grad.data = self._gpu_grad
else:
raise RuntimeError('No grad to write back')
# If using CPU offload, _cpu_grad will store the CPU tensor of _gpu_grad
# They should be released here
self._gpu_grad = None