Init Context supports lazy allocation of model memory (#842)

Jiarui Fang 2022-04-22 18:03:35 +08:00 committed by GitHub
parent 4575a3298b
commit 8789850eea
5 changed files with 112 additions and 13 deletions


@@ -43,9 +43,10 @@ class ColoTensor(object)
                             torch_tensor=tensor if save_payload else torch.empty(0))
         return colo_t
 
-    def del_torch_tensor(self) -> None:
-        self._size = (0,)
-        self._torch_tensor = torch.empty(self._size)
+    def del_torch_tensor(self, save_shape=False) -> None:
+        if save_shape:
+            self._size = (0,)
+        self._torch_tensor = torch.empty((0,))
 
     def torch_tensor(self) -> torch.Tensor:
         if self._torch_tensor.numel() == 0:

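For reference, here is a minimal sketch of the lazy-payload behavior at the ColoTensor level, using only the API visible in this commit (init_from_torch_tensor with save_payload, the private _torch_tensor buffer, torch_tensor(), and the new del_torch_tensor signature); it is an illustration, not part of the commit:

import torch
from colossalai.tensor import ColoTensor

# Build a ColoTensor without keeping the dense payload (save_payload=False):
# only metadata such as the recorded size is kept, _torch_tensor stays empty.
t = ColoTensor.init_from_torch_tensor(tensor=torch.randn(4, 5), save_payload=False)
assert t._torch_tensor.numel() == 0

# torch_tensor() materializes a dense tensor on first access.
dense = t.torch_tensor()
assert t._torch_tensor.numel() == 4 * 5

# del_torch_tensor() drops the payload again; the new save_shape flag decides
# whether the recorded size is reset as well.
t.del_torch_tensor()
assert t._torch_tensor.numel() == 0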

@@ -11,16 +11,47 @@ from .memory import (report_memory_usage, colo_device_memory_used, colo_set_proc
                      colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
 from .timer import MultiTimer, Timer
 from .tensor_detector import TensorDetector
-from .model.init_context import InsertPostInitMethodToModuleSubClasses
+from .model.utils import InsertPostInitMethodToModuleSubClasses
+from .model.colo_init_context import ColoInitContext
 
 __all__ = [
-    'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
-    'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'conditional_context',
-    'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
-    'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
-    'report_memory_usage', 'colo_device_memory_capacity', 'colo_device_memory_used', 'colo_set_process_memory_fraction',
-    'Timer', 'MultiTimer', 'multi_tensor_applier', 'DataParallelSampler', 'get_dataloader',
-    'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
-    'ensure_path_exists', 'disposable', 'colo_set_cpu_memory_capacity', 'colo_get_cpu_memory_capacity',
-    'InsertPostInitMethodToModuleSubClasses'
+    'checkpoint',
+    'free_port',
+    'print_rank_0',
+    'sync_model_param',
+    'is_dp_rank_0',
+    'is_tp_rank_0',
+    'is_no_pp_or_last_stage',
+    'is_using_ddp',
+    'is_using_pp',
+    'is_using_sequence',
+    'conditional_context',
+    'is_model_parallel_parameter',
+    'clip_grad_norm_fp32',
+    'count_zeros_fp32',
+    'copy_tensor_parallel_attributes',
+    'param_is_not_tensor_parallel_duplicate',
+    'get_current_device',
+    'synchronize',
+    'empty_cache',
+    'set_to_cuda',
+    'report_memory_usage',
+    'colo_device_memory_capacity',
+    'colo_device_memory_used',
+    'colo_set_process_memory_fraction',
+    'Timer',
+    'MultiTimer',
+    'multi_tensor_applier',
+    'DataParallelSampler',
+    'get_dataloader',
+    'switch_virtual_pipeline_parallel_rank',
+    'TensorDetector',
+    'load_checkpoint',
+    'save_checkpoint',
+    'ensure_path_exists',
+    'disposable',
+    'colo_set_cpu_memory_capacity',
+    'colo_get_cpu_memory_capacity',
+    'InsertPostInitMethodToModuleSubClasses',
+    'ColoInitContext',
 ]


@@ -0,0 +1,40 @@
+from .utils import InsertPostInitMethodToModuleSubClasses
+import torch
+# from colossalai.logging import get_dist_logger
+from colossalai.tensor import ColoTensor
+
+# _orig_torch_empty = torch.empty
+
+
+class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
+
+    def __init__(self, lazy_memory_allocate=False):
+        super().__init__()
+        self._lazy_memory_allocate = lazy_memory_allocate
+
+    def _pre_context_exec(self):
+        """
+        The Callback function when entering the context
+        """
+        pass
+
+    def _post_context_exec(self):
+        """The callback function when exiting context.
+        """
+        pass
+
+    def _post_init_method(self, module: torch.nn.Module):
+        """
+        The function to call at the end of the constructor of each module.
+        FIXME(fjr) The module may be passed to this function multiple times?
+        """
+        name_list = []
+        for name, param in module.named_parameters():
+            if isinstance(param, ColoTensor):
+                continue
+            name_list.append((name, param))
+
+        save_torch_payload = True if not self._lazy_memory_allocate else False
+        for name, param in name_list:
+            delattr(module, name)
+            setattr(module, name, ColoTensor.init_from_torch_tensor(tensor=param.data, save_payload=save_torch_payload))

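The base class InsertPostInitMethodToModuleSubClasses is not part of this diff (its import simply moved from .model.init_context to .model.utils above). The following standalone sketch, with hypothetical names, shows the hook pattern such a base class is assumed to implement, i.e. how _post_init_method can be made to run after every nn.Module constructor inside the context:

import functools
import torch


def _all_module_subclasses(cls=torch.nn.Module):
    # Recursively collect every currently defined nn.Module subclass.
    subclasses = set()
    for sub in cls.__subclasses__():
        subclasses.add(sub)
        subclasses |= _all_module_subclasses(sub)
    return subclasses


class PostInitHookContext:    # hypothetical stand-in for the real base class

    def _post_init_method(self, module: torch.nn.Module):
        raise NotImplementedError

    def __enter__(self):
        ctx = self

        def wrap(orig_init):

            @functools.wraps(orig_init)
            def wrapper(module, *args, **kwargs):
                orig_init(module, *args, **kwargs)
                ctx._post_init_method(module)    # fires after the real __init__

            return wrapper

        # Capture the original constructors first, then patch them for the
        # lifetime of the context.  Because super().__init__ calls hit the
        # patched parent classes too, the callback may fire more than once per
        # module, which is what the FIXME in _post_init_method refers to.
        self._orig_inits = {sub: sub.__init__ for sub in _all_module_subclasses()}
        for sub, orig_init in self._orig_inits.items():
            sub.__init__ = wrap(orig_init)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the original constructors on exit.
        for sub, orig_init in self._orig_inits.items():
            sub.__init__ = orig_init
        return False

ColoInitContext itself then only needs to override _post_init_method, as in the file above, to swap each freshly constructed parameter for a ColoTensor.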

@@ -0,0 +1,27 @@
+from colossalai.utils import ColoInitContext
+from numpy import allclose, require
+
+import torch
+from colossalai.tensor import ColoTensor
+from copy import deepcopy
+
+
+def test_linear():
+    in_dim = 4
+    out_dim = 5
+
+    with ColoInitContext(lazy_memory_allocate=True) as ctx:
+        fc = torch.nn.Linear(in_dim, out_dim, bias=True)
+
+    print(fc.weight.numel())
+    print(fc.bias.numel())
+
+    # lazy_memory_allocate=True, no payload is maintained
+    assert fc.weight._torch_tensor.numel() == 0
+
+    fc.weight.torch_tensor()
+    assert fc.weight._torch_tensor.numel() == in_dim * out_dim
+
+
+if __name__ == '__main__':
+    test_linear()
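As a complementary illustration (an assumption, not part of the commit's test), the default lazy_memory_allocate=False path should keep the dense payload at construction time, because _post_init_method then calls init_from_torch_tensor with save_payload=True:

import torch
from colossalai.utils import ColoInitContext
from colossalai.tensor import ColoTensor


def check_linear_eager():    # hypothetical companion check, not in the commit
    in_dim, out_dim = 4, 5
    with ColoInitContext(lazy_memory_allocate=False):
        fc = torch.nn.Linear(in_dim, out_dim, bias=True)

    # Parameters are replaced by ColoTensors that still hold their payload.
    assert isinstance(fc.weight, ColoTensor)
    assert fc.weight._torch_tensor.numel() == in_dim * out_dim


if __name__ == '__main__':
    check_linear_eager()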