mirror of https://github.com/hpcaitech/ColossalAI
Init Conext supports lazy allocate model memory (#842)
parent
4575a3298b
commit
8789850eea
|
@ -43,9 +43,10 @@ class ColoTensor(object):
|
||||||
torch_tensor=tensor if save_payload else torch.empty(0))
|
torch_tensor=tensor if save_payload else torch.empty(0))
|
||||||
return colo_t
|
return colo_t
|
||||||
|
|
||||||
def del_torch_tensor(self) -> None:
|
def del_torch_tensor(self, save_shape=False) -> None:
|
||||||
self._size = (0,)
|
if save_shape:
|
||||||
self._torch_tensor = torch.empty(self._size)
|
self._size = (0,)
|
||||||
|
self._torch_tensor = torch.empty((0,))
|
||||||
|
|
||||||
def torch_tensor(self) -> torch.Tensor:
|
def torch_tensor(self) -> torch.Tensor:
|
||||||
if self._torch_tensor.numel() == 0:
|
if self._torch_tensor.numel() == 0:
|
||||||
|
|
|
@ -11,16 +11,47 @@ from .memory import (report_memory_usage, colo_device_memory_used, colo_set_proc
|
||||||
colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
|
colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
|
||||||
from .timer import MultiTimer, Timer
|
from .timer import MultiTimer, Timer
|
||||||
from .tensor_detector import TensorDetector
|
from .tensor_detector import TensorDetector
|
||||||
from .model.init_context import InsertPostInitMethodToModuleSubClasses
|
from .model.utils import InsertPostInitMethodToModuleSubClasses
|
||||||
|
from .model.colo_init_context import ColoInitContext
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
|
'checkpoint',
|
||||||
'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'conditional_context',
|
'free_port',
|
||||||
'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
|
'print_rank_0',
|
||||||
'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
|
'sync_model_param',
|
||||||
'report_memory_usage', 'colo_device_memory_capacity', 'colo_device_memory_used', 'colo_set_process_memory_fraction',
|
'is_dp_rank_0',
|
||||||
'Timer', 'MultiTimer', 'multi_tensor_applier', 'DataParallelSampler', 'get_dataloader',
|
'is_tp_rank_0',
|
||||||
'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
|
'is_no_pp_or_last_stage',
|
||||||
'ensure_path_exists', 'disposable', 'colo_set_cpu_memory_capacity', 'colo_get_cpu_memory_capacity',
|
'is_using_ddp',
|
||||||
'InsertPostInitMethodToModuleSubClasses'
|
'is_using_pp',
|
||||||
|
'is_using_sequence',
|
||||||
|
'conditional_context',
|
||||||
|
'is_model_parallel_parameter',
|
||||||
|
'clip_grad_norm_fp32',
|
||||||
|
'count_zeros_fp32',
|
||||||
|
'copy_tensor_parallel_attributes',
|
||||||
|
'param_is_not_tensor_parallel_duplicate',
|
||||||
|
'get_current_device',
|
||||||
|
'synchronize',
|
||||||
|
'empty_cache',
|
||||||
|
'set_to_cuda',
|
||||||
|
'report_memory_usage',
|
||||||
|
'colo_device_memory_capacity',
|
||||||
|
'colo_device_memory_used',
|
||||||
|
'colo_set_process_memory_fraction',
|
||||||
|
'Timer',
|
||||||
|
'MultiTimer',
|
||||||
|
'multi_tensor_applier',
|
||||||
|
'DataParallelSampler',
|
||||||
|
'get_dataloader',
|
||||||
|
'switch_virtual_pipeline_parallel_rank',
|
||||||
|
'TensorDetector',
|
||||||
|
'load_checkpoint',
|
||||||
|
'save_checkpoint',
|
||||||
|
'ensure_path_exists',
|
||||||
|
'disposable',
|
||||||
|
'colo_set_cpu_memory_capacity',
|
||||||
|
'colo_get_cpu_memory_capacity',
|
||||||
|
'InsertPostInitMethodToModuleSubClasses',
|
||||||
|
'ColoInitContext',
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
from .utils import InsertPostInitMethodToModuleSubClasses
|
||||||
|
import torch
|
||||||
|
# from colossalai.logging import get_dist_logger
|
||||||
|
from colossalai.tensor import ColoTensor
|
||||||
|
|
||||||
|
# _orig_torch_empty = torch.empty
|
||||||
|
|
||||||
|
|
||||||
|
class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
|
|
||||||
|
def __init__(self, lazy_memory_allocate=False):
|
||||||
|
super().__init__()
|
||||||
|
self._lazy_memory_allocate = lazy_memory_allocate
|
||||||
|
|
||||||
|
def _pre_context_exec(self):
|
||||||
|
"""
|
||||||
|
The Callback function when entering the context
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _post_context_exec(self):
|
||||||
|
"""The callback function when exiting context.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _post_init_method(self, module: torch.nn.Module):
|
||||||
|
"""
|
||||||
|
The function to call at the end of the constructor of each module.
|
||||||
|
FIXME(fjr) The module may be passed to this function multiple times?
|
||||||
|
"""
|
||||||
|
name_list = []
|
||||||
|
for name, param in module.named_parameters():
|
||||||
|
if isinstance(param, ColoTensor):
|
||||||
|
continue
|
||||||
|
name_list.append((name, param))
|
||||||
|
|
||||||
|
save_torch_payload = True if not self._lazy_memory_allocate else False
|
||||||
|
for name, param in name_list:
|
||||||
|
delattr(module, name)
|
||||||
|
setattr(module, name, ColoTensor.init_from_torch_tensor(tensor=param.data, save_payload=save_torch_payload))
|
|
@ -0,0 +1,27 @@
|
||||||
|
from colossalai.utils import ColoInitContext
|
||||||
|
|
||||||
|
from numpy import allclose, require
|
||||||
|
import torch
|
||||||
|
from colossalai.tensor import ColoTensor
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
|
||||||
|
def test_linear():
|
||||||
|
in_dim = 4
|
||||||
|
out_dim = 5
|
||||||
|
|
||||||
|
with ColoInitContext(lazy_memory_allocate=True) as ctx:
|
||||||
|
fc = torch.nn.Linear(in_dim, out_dim, bias=True)
|
||||||
|
|
||||||
|
print(fc.weight.numel())
|
||||||
|
print(fc.bias.numel())
|
||||||
|
|
||||||
|
# lazy_memory_allocate=True, no payload is maintained
|
||||||
|
assert fc.weight._torch_tensor.numel() == 0
|
||||||
|
|
||||||
|
fc.weight.torch_tensor()
|
||||||
|
assert fc.weight._torch_tensor.numel() == in_dim * out_dim
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_linear()
|
Loading…
Reference in New Issue