mirror of https://github.com/hpcaitech/ColossalAI
146 lines
5.0 KiB
Python
146 lines
5.0 KiB
Python
from typing import Dict, List, Tuple
|
|
|
|
import torch
|
|
from torch.fx import Node
|
|
|
|
from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage
|
|
|
|
|
|
class Region:
|
|
"""
|
|
Region: A container owning a piece of contiguous nodes in the DNN computing graph.
|
|
|
|
Args:
|
|
r_id (int): the index of the region in the computing graph.
|
|
"""
|
|
|
|
def __init__(self, r_id: int = 0) -> None:
|
|
self.r_id: int = r_id
|
|
self.fp16_params: List[torch.nn.Parameter] = []
|
|
self.param_size: int = 0
|
|
self.shared_rid: int = self.r_id
|
|
|
|
self.param_num: int = 0
|
|
self.grad_num: int = 0
|
|
self.fp16_data = None
|
|
self.fp32_data = None
|
|
self.cpu_grad = None
|
|
self.temp_fp32_data = None
|
|
self.param_to_range: Dict[torch.nn.Parameter, Tuple[int, int]] = dict()
|
|
|
|
self.need_offload: bool = False
|
|
self.is_syn: bool = False
|
|
self.nodes: List[Node] = []
|
|
self.fwd_prefetch_region = None
|
|
self.bwd_prefetch_region = None
|
|
|
|
self.in_mem_pool_flag: bool = False
|
|
|
|
@property
|
|
def can_release(self) -> bool:
|
|
"""
|
|
Check if the region can be released.
|
|
"""
|
|
return self.grad_num == self.param_num
|
|
|
|
@property
|
|
def has_inf_or_nan(self) -> bool:
|
|
"""
|
|
Check if the grad of the region has inf or nan values on CUDA.
|
|
"""
|
|
return torch.isinf(self.fp16_data).any() | torch.isnan(self.fp16_data).any()
|
|
|
|
def init_param_data(self, pre_alloc_tensor: torch.Tensor = None):
|
|
"""
|
|
Map the parameters in the region to a contiguous memory space.
|
|
"""
|
|
|
|
self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda')
|
|
offset = 0
|
|
for param in self.fp16_params:
|
|
param.data = param.data.cuda()
|
|
p_num = param.data.numel()
|
|
self.fp16_data[offset:offset + p_num].copy_(param.data.flatten())
|
|
param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape)
|
|
self.param_to_range[param] = (offset, offset + p_num)
|
|
offset += p_num
|
|
|
|
self.fp32_data = self.fp16_data.float().cpu().pin_memory()
|
|
free_storage(self.fp16_data)
|
|
if self.in_mem_pool_flag and pre_alloc_tensor is not None:
|
|
self.fp16_data = pre_alloc_tensor
|
|
|
|
def move_param_to_cuda(self):
|
|
"""
|
|
Move parameters from CPU to GPU.
|
|
It first moves float32 parameters to GPU and
|
|
then transforms float32 parameters to half-precision on the GPU.
|
|
The reason is that the performance of precision conversion on the CPU
|
|
is much slower than the data transfer overhead.
|
|
"""
|
|
|
|
self.temp_fp32_data.copy_(self.fp32_data, non_blocking=True)
|
|
self.temp_fp32_data.record_stream(torch.cuda.current_stream())
|
|
if not self.in_mem_pool_flag:
|
|
alloc_storage(self.fp16_data)
|
|
self.fp16_data[:self.param_num].copy_(self.temp_fp32_data)
|
|
self.fp16_data.record_stream(torch.cuda.current_stream())
|
|
|
|
self.__update_params_ptr()
|
|
|
|
def move_grad_to_cpu(self):
|
|
"""
|
|
Move gradients from GPU to CPU.
|
|
"""
|
|
|
|
self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True)
|
|
self.cpu_grad.copy_(self.fp16_data[:self.param_num], non_blocking=True)
|
|
self.fp16_data.record_stream(torch.cuda.current_stream())
|
|
if not self.in_mem_pool_flag:
|
|
self.free_cuda_data()
|
|
|
|
self.grad_num = 0
|
|
|
|
def free_cuda_data(self):
|
|
free_storage(self.fp16_data)
|
|
|
|
# torch.cuda.empty_cache()
|
|
|
|
def copy_grad_to_region_slice(self, param: torch.nn.Parameter, data_slice: torch.Tensor) -> None:
|
|
"""
|
|
Copy data slice to the memory space indexed by the input tensor in the region.
|
|
|
|
Args:
|
|
param (torch.nn.Parameter): the param used to retrieve meta information
|
|
data_slice (torch.Tensor): the tensor to be copied to the region
|
|
"""
|
|
|
|
begin, end = self.param_to_range[param]
|
|
self.fp16_data[begin:end].copy_(data_slice.data.flatten())
|
|
param.data = self.fp16_data[begin:end].view(param.data.shape)
|
|
|
|
self.grad_num += data_slice.numel()
|
|
|
|
def split(self, cut_node_idx: int, cut_param_idx: int):
|
|
"""
|
|
Split the region into two and return the latter.
|
|
"""
|
|
new_reg = Region(r_id=self.r_id + 1)
|
|
new_reg.nodes = self.nodes[cut_node_idx:]
|
|
new_reg.fp16_params = self.fp16_params[cut_param_idx:]
|
|
for p in new_reg.fp16_params:
|
|
new_reg.param_size += p.data.numel() * p.data.element_size()
|
|
new_reg.param_num += p.data.numel()
|
|
|
|
self.nodes = self.nodes[:cut_node_idx]
|
|
self.fp16_params = self.fp16_params[:cut_param_idx]
|
|
self.param_size -= new_reg.param_size
|
|
self.param_num -= new_reg.param_num
|
|
|
|
return new_reg
|
|
|
|
def __update_params_ptr(self) -> None:
|
|
for param in self.fp16_params:
|
|
begin, end = self.param_to_range[param]
|
|
param.data = self.fp16_data[begin:end].view(param.data.shape)
|