# mirror of https://github.com/hpcaitech/ColossalAI
import torch

from colossalai.tensor import ColoTensor
from colossalai.utils.cuda import get_current_device

from .utils import InsertPostInitMethodToModuleSubClasses

# from colossalai.logging import get_dist_logger
# _orig_torch_empty = torch.empty


class ColoInitContext(InsertPostInitMethodToModuleSubClasses):

    def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
        """
        Args:
            lazy_memory_allocate (bool, optional): whether to allocate memory for the parameter tensors.
                Defaults to False.
            device (torch.device, optional): the device on which initialized parameters will reside.
                Defaults to torch.device('cpu').
        """
        super().__init__()
        self._lazy_memory_allocate = lazy_memory_allocate
        self._device = device

    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
        """
        The function called at the end of the constructor of each module.

        FIXME(fjr) The module may be passed to this function multiple times?
        """
        name_list = []
        for name, param in module.named_parameters():
            # Skip parameters that have already been converted to ColoTensor.
            if isinstance(param, ColoTensor):
                continue
            name_list.append((name, param))

        # Keep the torch payload unless lazy memory allocation was requested.
        save_torch_payload = not self._lazy_memory_allocate
        for name, param in name_list:
            # Re-bind the attribute to a ColoTensor built from the original
            # tensor moved to the target device.
            delattr(module, name)
            setattr(module, name,
                    ColoTensor.init_from_torch_tensor(tensor=param.to(self._device),
                                                      save_payload=save_torch_payload))
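

# A minimal usage sketch (not part of the original file): it assumes that
# InsertPostInitMethodToModuleSubClasses makes this class usable as a context
# manager which runs _post_init_method after the constructor of every
# nn.Module built inside the `with` block, so freshly created parameters are
# wrapped as ColoTensor on the requested device. The toy model is purely
# illustrative.
if __name__ == '__main__':
    with ColoInitContext(device=torch.device('cpu')):
        model = torch.nn.Linear(16, 4)

    # The parameter attributes should now be ColoTensor instances.
    print(type(model.weight), type(model.bias))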