From eb1b89908c9a4a35ff299b24dea7217e98d0ef37 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Thu, 21 Apr 2022 16:03:18 +0800
Subject: [PATCH] [refactor] moving InsertPostInitMethodToModuleSubClasses to utils. (#824)

---
 colossalai/utils/__init__.py             |  4 +-
 colossalai/utils/model/init_context.py   | 81 ++++++++++++++++++++++++
 colossalai/zero/init_ctx/init_context.py | 77 +---------------------
 3 files changed, 85 insertions(+), 77 deletions(-)
 create mode 100644 colossalai/utils/model/init_context.py

diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py
index 6e1720b3d..a35d3dcc3 100644
--- a/colossalai/utils/__init__.py
+++ b/colossalai/utils/__init__.py
@@ -11,6 +11,7 @@ from .memory import (report_memory_usage, colo_device_memory_used, colo_set_proc
                      colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
 from .timer import MultiTimer, Timer
 from .tensor_detector import TensorDetector
+from .model.init_context import InsertPostInitMethodToModuleSubClasses
 
 __all__ = [
     'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
@@ -20,5 +21,6 @@ __all__ = [
     'report_memory_usage', 'colo_device_memory_capacity', 'colo_device_memory_used', 'colo_set_process_memory_fraction',
     'Timer', 'MultiTimer', 'multi_tensor_applier', 'DataParallelSampler', 'get_dataloader',
     'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
-    'ensure_path_exists', 'disposable', 'colo_set_cpu_memory_capacity', 'colo_get_cpu_memory_capacity'
+    'ensure_path_exists', 'disposable', 'colo_set_cpu_memory_capacity', 'colo_get_cpu_memory_capacity',
+    'InsertPostInitMethodToModuleSubClasses'
 ]
diff --git a/colossalai/utils/model/init_context.py b/colossalai/utils/model/init_context.py
new file mode 100644
index 000000000..ced1365ec
--- /dev/null
+++ b/colossalai/utils/model/init_context.py
@@ -0,0 +1,81 @@
+import torch
+import functools
+from typing import Optional
+
+
+def _substitute_init_recursively(cls, func):
+    for subcls in cls.__subclasses__():
+        _substitute_init_recursively(subcls, func)
+        func(subcls)
+
+
+class InsertPostInitMethodToModuleSubClasses(object):
+
+    def __init__(self, default_dtype: Optional[torch.dtype] = None):
+        self._old_default_dtype = None
+        self._default_dtype = default_dtype
+
+    def __enter__(self):
+        r"""
+        Enter the context scope.
+        """
+        if self._default_dtype is not None:
+            self._old_default_dtype = torch.get_default_dtype()
+            torch.set_default_dtype(self._default_dtype)
+
+        def preprocess_after(f):
+
+            @functools.wraps(f)
+            def wrapper(module: torch.nn.Module, *args, **kwargs):
+                f(module, *args, **kwargs)
+                self._post_init_method(module)
+
+            return wrapper
+
+        def _enable_class(cls):
+            cls._old_init = cls.__init__
+            cls.__init__ = preprocess_after(cls.__init__)
+
+        # Called through __init_subclass__ whenever a new subclass of torch.nn.Module is defined.
+        def _init_subclass(cls, **kwargs):
+            cls.__init__ = preprocess_after(cls.__init__)
+
+        # Replace .__init__() for all existing subclasses of torch.nn.Module.
+        # Execute self._post_init_method after the original __init__ finishes.
+
+        _substitute_init_recursively(torch.nn.modules.module.Module, _enable_class)
+
+        # Hold on to the current __init_subclass__ so it can be restored on exit.
+        torch.nn.modules.module.Module._old_init_subclass = (torch.nn.modules.module.Module.__init_subclass__)
+        # Replace .__init__() for future subclasses of torch.nn.Module
+        torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass)
+
+        self._pre_context_exec()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+
+        if self._default_dtype is not None:
+            torch.set_default_dtype(self._old_default_dtype)
+
+        def _disable_class(cls):
+            cls.__init__ = cls._old_init
+
+        # Restore the original .__init__() for all existing subclasses of torch.nn.Module
+        _substitute_init_recursively(torch.nn.modules.module.Module, _disable_class)
+
+        # Restore the saved __init_subclass__ so future subclasses are no longer patched
+        torch.nn.modules.module.Module.__init_subclass__ = (torch.nn.modules.module.Module._old_init_subclass)
+
+        self._post_context_exec()
+        # Cleanup is done; returning False lets any in-context exception propagate.
+        if exc_type is not None:
+            return False
+
+    # To be implemented by inheriting classes
+    def _post_init_method(self, module):
+        pass
+
+    def _pre_context_exec(self):
+        pass
+
+    def _post_context_exec(self):
+        pass
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index bc41ec1e7..b26e82919 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -13,82 +13,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from contextlib import AbstractContextManager
-
-
-def _substitute_init_recursively(cls, func):
-    for subcls in cls.__subclasses__():
-        _substitute_init_recursively(subcls, func)
-        func(subcls)
-
-
-class InsertPostInitMethodToModuleSubClasses(object):
-
-    def __init__(self, default_dtype: Optional[torch.dtype] = None):
-        self._old_default_dtype = None
-        self._default_dtype = default_dtype
-
-    def __enter__(self):
-        r"""
-        Enter the context scope.
-        """
-        if self._default_dtype is not None:
-            self._old_default_dtype = torch.get_default_dtype()
-            torch.set_default_dtype(self._default_dtype)
-        def preprocess_after(f):
-
-            @functools.wraps(f)
-            def wrapper(module: torch.nn.Module, *args, **kwargs):
-                f(module, *args, **kwargs)
-                self._post_init_method(module)
-
-            return wrapper
-
-        def _enable_class(cls):
-            cls._old_init = cls.__init__
-            cls.__init__ = preprocess_after(cls.__init__)
-
-        # The function is called during init subclass.
-        def _init_subclass(cls, **kwargs):
-            cls.__init__ = preprocess_after(cls.__init__)
-
-        # Replace .__init__() for all existing subclasses of torch.nn.Module
-        # Excution self._post_init_method after the default init function.
- _substitute_init_recursively(torch.nn.modules.module.Module, _enable_class) - - # holding on to the current __init__subclass__ for exit - torch.nn.modules.module.Module._old_init_subclass = (torch.nn.modules.module.Module.__init_subclass__) - # Replace .__init__() for future subclasses of torch.nn.Module - torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass) - - self._pre_context_exec() - - def __exit__(self, exc_type, exc_value, traceback): - - if self._default_dtype is not None: - torch.set_default_dtype(self._old_default_dtype) - def _disable_class(cls): - cls.__init__ = cls._old_init - - # Replace .__init__() for all existing subclasses of torch.nn.Module - _substitute_init_recursively(torch.nn.modules.module.Module, _disable_class) - - # Replace .__init__() for future subclasses of torch.nn.Module - torch.nn.modules.module.Module.__init_subclass__ = (torch.nn.modules.module.Module._old_init_subclass) - - self._post_context_exec() - # Now that we cleaned up the metaclass injection, raise the exception. - if exc_type is not None: - return False - - # To be implemented by inheriting classes - def _post_init_method(self, module): - pass - - def _pre_context_exec(self): - pass - - def _post_context_exec(self): - pass +from colossalai.utils import InsertPostInitMethodToModuleSubClasses class ZeroContextConfig(object):
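
Usage sketch: subclasses override _post_init_method to act on every module
constructed inside the context; the updated
colossalai/zero/init_ctx/init_context.py imports the class for exactly this
purpose. The subclass below, CastToHalfContext, is a hypothetical example
written only for illustration; InsertPostInitMethodToModuleSubClasses itself
is the API introduced by this patch.

    import torch

    from colossalai.utils import InsertPostInitMethodToModuleSubClasses


    class CastToHalfContext(InsertPostInitMethodToModuleSubClasses):
        """Cast each module's own parameters to fp16 right after its __init__."""

        def _post_init_method(self, module: torch.nn.Module):
            # Called once per constructed module, after its original __init__.
            for param in module.parameters(recurse=False):
                param.data = param.data.half()


    with CastToHalfContext():
        model = torch.nn.Linear(4, 4)    # _post_init_method fires here

    assert model.weight.dtype == torch.float16

Passing default_dtype=torch.float16 to the constructor would additionally make
torch.get_default_dtype() return fp16 inside the context, with the previous
default restored on __exit__.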