From de0468c7a84ba65e735a5ea6285619989e4116e4 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Mon, 7 Mar 2022 16:14:40 +0800
Subject: [PATCH] [zero] zero init context (#321)

* add zero init context

* add more flags for zero init context
  fix bug of repeatedly converting a param to ShardedParamV2

* polish code
---
 colossalai/zero/init_ctx/__init__.py          |   3 +
 colossalai/zero/init_ctx/init_context.py      | 128 ++++++++++++++++++
 .../zero/sharded_param/sharded_param.py       |   9 +-
 .../test_init_context.py                      |  38 ++++++
 4 files changed, 173 insertions(+), 5 deletions(-)
 create mode 100644 colossalai/zero/init_ctx/__init__.py
 create mode 100644 colossalai/zero/init_ctx/init_context.py
 create mode 100644 tests/test_zero_data_parallel/test_init_context.py

diff --git a/colossalai/zero/init_ctx/__init__.py b/colossalai/zero/init_ctx/__init__.py
new file mode 100644
index 000000000..804b36b02
--- /dev/null
+++ b/colossalai/zero/init_ctx/__init__.py
@@ -0,0 +1,3 @@
+from .init_context import ZeroInitContext
+
+__all__ = ['ZeroInitContext']
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
new file mode 100644
index 000000000..d7bd82c27
--- /dev/null
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -0,0 +1,128 @@
+import functools
+from colossalai.utils.cuda import get_current_device
+import torch
+from colossalai.zero.shard_utils import BaseShardStrategy
+from colossalai.zero.sharded_param import ShardedParamV2
+
+
+# Inserts _post_init_method at the end of the __init__ method
+# of every subclass of torch.nn.Module.
+class InsertPostInitMethodToModuleSubClasses(object):
+
+    def __init__(self):
+        pass
+
+    def __enter__(self):
+        r"""
+        Enter the context scope.
+        """
+
+        def preprocess_after(f):
+
+            @functools.wraps(f)
+            def wrapper(module: torch.nn.Module, *args, **kwargs):
+                f(module, *args, **kwargs)
+                self._post_init_method(module)
+
+            return wrapper
+
+        def _enable_class(cls):
+            cls._old_init = cls.__init__
+            cls.__init__ = preprocess_after(cls.__init__)
+
+        # This function is called when a new subclass of torch.nn.Module is defined.
+        def _init_subclass(cls, **kwargs):
+            cls.__init__ = preprocess_after(cls.__init__)
+
+        # Replace .__init__() for all existing subclasses of torch.nn.Module
+        # so that self._post_init_method runs after the default __init__.
+        for subclass in torch.nn.modules.module.Module.__subclasses__():
+            _enable_class(subclass)
+
+        # Hold on to the current __init_subclass__ so it can be restored on exit.
+        torch.nn.modules.module.Module._old_init_subclass = (torch.nn.modules.module.Module.__init_subclass__)
+        # Replace .__init__() for future subclasses of torch.nn.Module.
+        torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)
+
+        self._pre_context_exec()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+
+        def _disable_class(cls):
+            cls.__init__ = cls._old_init
+
+        # Restore .__init__() for all existing subclasses of torch.nn.Module.
+        for subclass in torch.nn.modules.module.Module.__subclasses__():
+            _disable_class(subclass)
+
+        # Restore __init_subclass__ for future subclasses of torch.nn.Module.
+        torch.nn.modules.module.Module.__init_subclass__ = (torch.nn.modules.module.Module._old_init_subclass)
+
+        self._post_context_exec()
+        # Now that the __init__ injection has been cleaned up, let any exception propagate.
+        if exc_type is not None:
+            return False
+
+    # To be implemented by inheriting classes.
+    def _post_init_method(self, module):
+        pass
+
+    def _pre_context_exec(self):
+        pass
+
+    def _post_context_exec(self):
+        pass
+
+
+class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
+    """
+    A context to initialize the model.
+    1. Convert the model to fp16.
+    2. The parameters of the module are adapted to type ShardedParamV2.
+    3. Shard the param and grad according to the flags.
+    """
+
+    def __init__(
+        self,
+        convert_fp16: bool,
+        convert_cuda: bool,
+        shard_strategy: BaseShardStrategy,
+        shard_param: bool = False,
+        shard_grad: bool = False,
+    ):
+        super().__init__()
+        self.convert_fp16 = convert_fp16
+        self.convert_cuda = convert_cuda
+        self.shard_param = shard_param
+        self.shard_grad = shard_grad
+        self.shard_strategy = shard_strategy
+
+    def _post_context_exec(self):
+        """The callback function called when the context exits.
+        """
+        pass
+
+    def _post_init_method(self, module):
+        r"""The function called at the end of the constructor of each nn.Module.
+        """
+        for param in module.parameters():
+            # avoid adapting a param to ShardedParamV2 twice
+            if hasattr(param, 'ca_attr'):
+                continue
+
+            if self.convert_cuda:
+                target_device = get_current_device()
+            else:
+                target_device = param.data.device
+
+            # convert to fp16 and move to the target device if necessary
+            if self.convert_fp16:
+                param.data = param.data.to(torch.half).to(target_device)
+                if param.grad is not None:
+                    param.grad = param.grad.to(torch.half).to(target_device)
+
+            param.ca_attr = ShardedParamV2(param)
+            if self.shard_param:
+                self.shard_strategy.shard(tensor_list=[param.ca_attr._data_sharded_tensor])
+            if param.ca_attr.grad is not None and self.shard_grad:
+                self.shard_strategy.shard(tensor_list=[param.ca_attr._grad_sharded_tensor])
diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py
index 61e9d9d32..1358bcc3a 100644
--- a/colossalai/zero/sharded_param/sharded_param.py
+++ b/colossalai/zero/sharded_param/sharded_param.py
@@ -1,5 +1,3 @@
-from typing import Optional, Tuple, Union
-
 import numpy
 import torch
 import torch.distributed as dist
@@ -8,8 +6,6 @@ from colossalai.core import global_context as gpc
 from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.zero.sharded_param import ShardedTensor
 from typing import Union, Tuple, Optional
-import numpy
-
 
 class ShardedParamV2(object):
 
@@ -35,7 +31,10 @@ class ShardedParamV2(object):
 
     @property
     def grad(self):
-        return self._grad_sharded_tensor.payload
+        if self._grad_sharded_tensor:
+            return self._grad_sharded_tensor.payload
+        else:
+            return None
 
     @grad.setter
     def grad(self, t: torch.Tensor):
diff --git a/tests/test_zero_data_parallel/test_init_context.py b/tests/test_zero_data_parallel/test_init_context.py
new file mode 100644
index 000000000..cf038844c
--- /dev/null
+++ b/tests/test_zero_data_parallel/test_init_context.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from functools import partial
+
+import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
+from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
+from colossalai.zero.init_ctx import ZeroInitContext
+from common import CONFIG, Net
+from colossalai.utils import free_port
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=TensorShardStrategy(), shard_param=True):
+        # Note: calling .cuda() on Net(checkpoint=True) is unnecessary; the context already moves params to CUDA.
+        model = Net(checkpoint=True)
+
+    for param in model.parameters():
+        assert hasattr(param, 'ca_attr')
+        assert param.ca_attr.data.dtype == torch.half
+        assert param.ca_attr._data_sharded_tensor.is_sharded
+        assert param.ca_attr.data.device.type == 'cuda'
+
+
+@pytest.mark.dist
+def test_zero_init_context():
+    world_size = 2
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_zero_init_context()
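
For reviewers who want to try the new context outside the test harness, here is a minimal usage sketch. It is not part of the patch: it assumes colossalai.launch(...) has already initialized the distributed environment (as run_dist does above) and substitutes a plain torch.nn.Sequential for the test's Net from common; the flags mirror the ones exercised by the test.

# Usage sketch (illustrative only, not part of the patch).
# Assumes colossalai.launch(...) has already set up the process groups;
# torch.nn.Sequential is a hypothetical stand-in for the test's Net.
import torch
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy


def build_sharded_model() -> torch.nn.Module:
    # Every nn.Module constructed inside the context goes through
    # _post_init_method: its parameters are cast to fp16, moved to the
    # current CUDA device, and wrapped in a ShardedParamV2 stored on
    # param.ca_attr; with shard_param=True the data payload is sharded
    # across the data-parallel group right away.
    with ZeroInitContext(convert_fp16=True,
                         convert_cuda=True,
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True):
        model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))

    # Exiting the context restores the original __init__ methods, but the
    # ca_attr handles (and the sharded payloads) stay attached to the params.
    for param in model.parameters():
        assert hasattr(param, 'ca_attr')
        assert param.ca_attr.data.dtype == torch.half
    return model

Because _post_init_method skips any parameter that already carries a ca_attr, nesting modules (as Sequential does here) does not convert or shard the same parameter twice.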