ColossalAI/colossalai/auto_parallel/offload/region.py

146 lines
5.0 KiB
Python

from typing import Dict, List, Tuple
import torch
from torch.fx import Node
from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage
class Region:
"""
Region: A container owning a piece of contiguous nodes in the DNN computing graph.
Args:
r_id (int): the index of the region in the computing graph.
"""
def __init__(self, r_id: int = 0) -> None:
self.r_id: int = r_id
self.fp16_params: List[torch.nn.Parameter] = []
self.param_size: int = 0
self.shared_rid: int = self.r_id
self.param_num: int = 0
self.grad_num: int = 0
self.fp16_data = None
self.fp32_data = None
self.cpu_grad = None
self.temp_fp32_data = None
self.param_to_range: Dict[torch.nn.Parameter, Tuple[int, int]] = dict()
self.need_offload: bool = False
self.is_syn: bool = False
self.nodes: List[Node] = []
self.fwd_prefetch_region = None
self.bwd_prefetch_region = None
self.in_mem_pool_flag: bool = False
@property
def can_release(self) -> bool:
"""
Check if the region can be released.
"""
return self.grad_num == self.param_num
@property
def has_inf_or_nan(self) -> bool:
"""
Check if the grad of the region has inf or nan values on CUDA.
"""
return torch.isinf(self.fp16_data).any() | torch.isnan(self.fp16_data).any()
def init_param_data(self, pre_alloc_tensor: torch.Tensor = None):
"""
Map the parameters in the region to a contiguous memory space.
"""
self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device="cuda")
offset = 0
for param in self.fp16_params:
param.data = param.data.cuda()
p_num = param.data.numel()
self.fp16_data[offset : offset + p_num].copy_(param.data.flatten())
param.data = self.fp16_data[offset : offset + p_num].view(param.data.shape)
self.param_to_range[param] = (offset, offset + p_num)
offset += p_num
self.fp32_data = self.fp16_data.float().cpu().pin_memory()
free_storage(self.fp16_data)
if self.in_mem_pool_flag and pre_alloc_tensor is not None:
self.fp16_data = pre_alloc_tensor
def move_param_to_cuda(self):
"""
Move parameters from CPU to GPU.
It first moves float32 parameters to GPU and
then transforms float32 parameters to half-precision on the GPU.
The reason is that the performance of precision conversion on the CPU
is much slower than the data transfer overhead.
"""
self.temp_fp32_data.copy_(self.fp32_data, non_blocking=True)
self.temp_fp32_data.record_stream(torch.cuda.current_stream())
if not self.in_mem_pool_flag:
alloc_storage(self.fp16_data)
self.fp16_data[: self.param_num].copy_(self.temp_fp32_data)
self.fp16_data.record_stream(torch.cuda.current_stream())
self.__update_params_ptr()
def move_grad_to_cpu(self):
"""
Move gradients from GPU to CPU.
"""
self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True)
self.cpu_grad.copy_(self.fp16_data[: self.param_num], non_blocking=True)
self.fp16_data.record_stream(torch.cuda.current_stream())
if not self.in_mem_pool_flag:
self.free_cuda_data()
self.grad_num = 0
def free_cuda_data(self):
free_storage(self.fp16_data)
# torch.cuda.empty_cache()
def copy_grad_to_region_slice(self, param: torch.nn.Parameter, data_slice: torch.Tensor) -> None:
"""
Copy data slice to the memory space indexed by the input tensor in the region.
Args:
param (torch.nn.Parameter): the param used to retrieve meta information
data_slice (torch.Tensor): the tensor to be copied to the region
"""
begin, end = self.param_to_range[param]
self.fp16_data[begin:end].copy_(data_slice.data.flatten())
param.data = self.fp16_data[begin:end].view(param.data.shape)
self.grad_num += data_slice.numel()
def split(self, cut_node_idx: int, cut_param_idx: int):
"""
Split the region into two and return the latter.
"""
new_reg = Region(r_id=self.r_id + 1)
new_reg.nodes = self.nodes[cut_node_idx:]
new_reg.fp16_params = self.fp16_params[cut_param_idx:]
for p in new_reg.fp16_params:
new_reg.param_size += p.data.numel() * p.data.element_size()
new_reg.param_num += p.data.numel()
self.nodes = self.nodes[:cut_node_idx]
self.fp16_params = self.fp16_params[:cut_param_idx]
self.param_size -= new_reg.param_size
self.param_num -= new_reg.param_num
return new_reg
def __update_params_ptr(self) -> None:
for param in self.fp16_params:
begin, end = self.param_to_range[param]
param.data = self.fp16_data[begin:end].view(param.data.shape)