[zero] zero init context (#321)

* add zero init context

* add more flags for zero init context
fix bug of repeatedly converting a param to ShardedParamV2

* polish code
pull/394/head
Jiarui Fang 2022-03-07 16:14:40 +08:00 committed by Frank Lee
parent 73bff11288
commit de0468c7a8
4 changed files with 173 additions and 5 deletions

@@ -0,0 +1,3 @@
from .init_context import ZeroInitContext

__all__ = ['ZeroInitContext']

@@ -0,0 +1,128 @@
import functools

from colossalai.utils.cuda import get_current_device
import torch
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param import ShardedParamV2


# Inserts _post_init_method at the end of the __init__ method
# of every subclass of torch.nn.Module.
class InsertPostInitMethodToModuleSubClasses(object):

    def __init__(self):
        pass

    def __enter__(self):
        r"""
        Enter the context scope.
        """

        def preprocess_after(f):

            @functools.wraps(f)
            def wrapper(module: torch.nn.Module, *args, **kwargs):
                f(module, *args, **kwargs)
                self._post_init_method(module)

            return wrapper

        def _enable_class(cls):
            cls._old_init = cls.__init__
            cls.__init__ = preprocess_after(cls.__init__)

        # This function is called whenever a new subclass of torch.nn.Module is defined.
        def _init_subclass(cls, **kwargs):
            cls.__init__ = preprocess_after(cls.__init__)

        # Replace .__init__() for all existing subclasses of torch.nn.Module,
        # so that self._post_init_method runs after the default __init__.
        for subclass in torch.nn.modules.module.Module.__subclasses__():
            _enable_class(subclass)

        # Hold on to the current __init_subclass__ so it can be restored on exit.
        torch.nn.modules.module.Module._old_init_subclass = (torch.nn.modules.module.Module.__init_subclass__)
        # Replace .__init__() for future subclasses of torch.nn.Module.
        torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)

        self._pre_context_exec()

    def __exit__(self, exc_type, exc_value, traceback):

        def _disable_class(cls):
            cls.__init__ = cls._old_init

        # Restore .__init__() for all existing subclasses of torch.nn.Module.
        for subclass in torch.nn.modules.module.Module.__subclasses__():
            _disable_class(subclass)

        # Restore .__init__() for future subclasses of torch.nn.Module.
        torch.nn.modules.module.Module.__init_subclass__ = (torch.nn.modules.module.Module._old_init_subclass)

        self._post_context_exec()

        # Now that the patched methods are cleaned up, let any exception propagate.
        if exc_type is not None:
            return False

    # To be implemented by inheriting classes
    def _post_init_method(self, module):
        pass

    def _pre_context_exec(self):
        pass

    def _post_context_exec(self):
        pass


class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
    """
    A context to initialize the model.

    1. Convert the model to fp16.
    2. The parameters of the module are adapted to type ShardedParamV2.
    3. Shard the param and grad according to the flags.
    """

    def __init__(self,
                 convert_fp16: bool,
                 convert_cuda: bool,
                 shard_strategy: BaseShardStrategy,
                 shard_param: bool = False,
                 shard_grad: bool = False):
        super().__init__()
        self.convert_fp16 = convert_fp16
        self.convert_cuda = convert_cuda
        self.shard_param = shard_param
        self.shard_grad = shard_grad
        self.shard_strategy = shard_strategy

    def _post_context_exec(self):
        """The callback function when the context exits.
        """
        pass

    def _post_init_method(self, module):
        r"""The function to call at the end of the constructor of each nn.Module.
        """
        for param in module.parameters():
            # avoid converting a param to ShardedParamV2 twice
            if hasattr(param, 'ca_attr'):
                continue

            if self.convert_cuda:
                target_device = get_current_device()
            else:
                target_device = param.data.device

            # convert to fp16 and move to the target device if necessary
            if self.convert_fp16:
                param.data = param.data.to(torch.half).to(target_device)
                if param.grad is not None:
                    param.grad = param.grad.to(torch.half).to(target_device)

            param.ca_attr = ShardedParamV2(param)
            if self.shard_param:
                self.shard_strategy.shard(tensor_list=[param.ca_attr._data_sharded_tensor])
            if param.ca_attr.grad and self.shard_grad:
                self.shard_strategy.shard(tensor_list=[param.ca_attr._grad_sharded_tensor])
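
For illustration only, a minimal usage sketch of the new context (not part of this commit): MyModel below is a placeholder module, and the sketch assumes a distributed environment has already been set up with colossalai.launch, as in the test file at the end of this diff.

import torch
import torch.nn as nn
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy


class MyModel(nn.Module):
    # placeholder model; any nn.Module works, because the context patches
    # the __init__ of every nn.Module subclass while it is active
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 16)


# build the model inside the context: parameters are converted to fp16,
# moved to the current CUDA device, wrapped in ShardedParamV2 and sharded
with ZeroInitContext(convert_fp16=True,
                     convert_cuda=True,
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True):
    model = MyModel()

for p in model.parameters():
    assert hasattr(p, 'ca_attr')                # ShardedParamV2 handle attached by the context
    assert p.ca_attr.data.dtype == torch.half   # parameter data stored in fp16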

@@ -1,5 +1,3 @@
from typing import Optional, Tuple, Union
import numpy
import torch
import torch.distributed as dist

@@ -8,8 +6,6 @@ from colossalai.core import global_context as gpc
from colossalai.zero.sharded_model._zero3_utils import get_shard
from colossalai.zero.sharded_param import ShardedTensor
from typing import Union, Tuple, Optional
import numpy

class ShardedParamV2(object):

@@ -35,7 +31,10 @@ class ShardedParamV2(object):
    @property
    def grad(self):
        return self._grad_sharded_tensor.payload
        if self._grad_sharded_tensor:
            return self._grad_sharded_tensor.payload
        else:
            return None

    @grad.setter
    def grad(self, t: torch.Tensor):

@@ -0,0 +1,38 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.utils import free_port

from common import CONFIG, Net


def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=TensorShardStrategy(), shard_param=True):
        # Note: calling .cuda() on Net(checkpoint=True) would be redundant here,
        # since the context already moves parameters to the current CUDA device.
        model = Net(checkpoint=True)

    for param in model.parameters():
        assert hasattr(param, 'ca_attr')
        assert param.ca_attr.data.dtype == torch.half
        assert param.ca_attr._data_sharded_tensor.is_sharded
        assert param.ca_attr.data.device.type == 'cuda'


@pytest.mark.dist
def test_zero_init_context():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_init_context()