diff --git a/colossalai/engine/ophooks/__init__.py b/colossalai/engine/ophooks/__init__.py
index ee2a8ac44..8a1071e38 100644
--- a/colossalai/engine/ophooks/__init__.py
+++ b/colossalai/engine/ophooks/__init__.py
@@ -64,18 +64,13 @@ class PostBackwardFunction(torch.autograd.Function):
 def register_ophooks_recursively(module: torch.nn.Module, ophook_list: List[BaseOpHook] = None, name: str = ""):
     r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
     assert isinstance(module, torch.nn.Module)
-    has_children = False
+
+    # Add hooks for submodules
     for child_name, child in module.named_children():
         register_ophooks_recursively(child, ophook_list, name + child_name)
-        has_children = True
 
-    # Early return on modules with no parameters or buffers that
-    # are not in their children.
-    if (len(list(module.named_parameters(recurse=False))) == 0 and len(list(module.named_buffers(recurse=False))) == 0):
-        return
-
-    # return if the module has not childern.
-    if has_children:
+    # Early return on modules with no parameters.
+    if len(list(module.parameters(recurse=False))) == 0:
         return
 
     if ophook_list is not None:
diff --git a/colossalai/engine/ophooks/zero_hook.py b/colossalai/engine/ophooks/zero_hook.py
index 051dd7a87..01d3a08bb 100644
--- a/colossalai/engine/ophooks/zero_hook.py
+++ b/colossalai/engine/ophooks/zero_hook.py
@@ -31,11 +31,11 @@ class ZeroHook(BaseOpHook):
 
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         tensor_list = []
-        for param in module.parameters():
+        for param in module.parameters(recurse=False):
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
-        for param in module.parameters():
+        for param in module.parameters(recurse=False):
             colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
             param.data = param.col_attr.sharded_data_tensor.payload
 
@@ -44,20 +44,20 @@
 
     def post_fwd_exec(self, module: torch.nn.Module, *args):
         tensor_list = []
-        for param in module.parameters():
+        for param in module.parameters(recurse=False):
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
-        for param in module.parameters():
+        for param in module.parameters(recurse=False):
             param.col_attr.remove_torch_payload()
 
     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         tensor_list = []
-        for param in module.parameters():
+        for param in module.parameters(recurse=False):
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
-        for param in module.parameters():
+        for param in module.parameters(recurse=False):
             colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
             param.data = param.col_attr.sharded_data_tensor.payload
             # Store local accumulated grad shard
@@ -77,11 +77,11 @@
 
     def post_bwd_exec(self, module: torch.nn.Module, input):
         tensor_list = []
-        for param in module.parameters():
+        for param in module.parameters(recurse=False):
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
-        for param in module.parameters():
+        for param in module.parameters(recurse=False):
             param.col_attr.remove_torch_payload()
 
     def pre_iter(self):
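The hooks above now iterate `module.parameters(recurse=False)` because `register_ophooks_recursively` installs a separate hook on every submodule; each hook should therefore gather and shard only the parameters its own module owns directly. A minimal sketch of the difference in plain PyTorch (the module layout mirrors the `no_leaf_module` test component added below; names here are illustrative):

```python
import torch
import torch.nn as nn


class NoLeafModule(nn.Module):
    """Owns a parameter directly and also has child modules."""

    def __init__(self):
        super().__init__()
        self.proj1 = nn.Linear(4, 8)
        self.weight = nn.Parameter(torch.randn(8, 8))
        self.proj2 = nn.Linear(8, 4)


m = NoLeafModule()
# The default recurse=True also walks the children: proj1.{weight,bias},
# proj2.{weight,bias} plus the module's own `weight` -> 5 tensors.
print(len(list(m.parameters())))               # 5
# recurse=False yields only the directly owned parameter, so the hooks
# registered on `m` touch `weight` exactly once; proj1 and proj2 are
# covered by the hooks registered on them.
print(len(list(m.parameters(recurse=False))))  # 1
```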
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index bf4a44aec..00aef1f1e 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -12,6 +12,12 @@ from torch.distributed import ProcessGroup
 from colossalai.logging import get_dist_logger, disable_existing_loggers
 
 
+def _substitute_init_recursively(cls, func):
+    for subcls in cls.__subclasses__():
+        _substitute_init_recursively(subcls, func)
+        func(subcls)
+
+
 class InsertPostInitMethodToModuleSubClasses(object):
 
     def __init__(self):
@@ -41,8 +47,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
 
         # Replace .__init__() for all existing subclasses of torch.nn.Module
         # Excution self._post_init_method after the default init function.
-        for subclass in torch.nn.modules.module.Module.__subclasses__():
-            _enable_class(subclass)
+        _substitute_init_recursively(torch.nn.modules.module.Module, _enable_class)
 
         # holding on to the current __init__subclass__ for exit
         torch.nn.modules.module.Module._old_init_subclass = (torch.nn.modules.module.Module.__init_subclass__)
@@ -57,8 +62,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
             cls.__init__ = cls._old_init
 
         # Replace .__init__() for all existing subclasses of torch.nn.Module
-        for subclass in torch.nn.modules.module.Module.__subclasses__():
-            _disable_class(subclass)
+        _substitute_init_recursively(torch.nn.modules.module.Module, _disable_class)
 
         # Replace .__init__() for future subclasses of torch.nn.Module
         torch.nn.modules.module.Module.__init_subclass__ = (torch.nn.modules.module.Module._old_init_subclass)
@@ -144,7 +148,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         The function to call at the end of the constructor of each module.
         NOTE() The module may be passed to this function multiple times.
         """
-        for param in module.parameters():
+        for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice
             if hasattr(param, 'col_attr'):
                 continue
@@ -173,7 +177,7 @@
        # We must cast buffers
        # If we use BN, buffers may be on CPU and Float
        # We must cast them
-        for buffer in module.buffers():
+        for buffer in module.buffers(recurse=False):
            buffer.data = buffer.data.to(device=torch.cuda.current_device())
            if self.convert_fp16:
                buffer.data = cast_tensor_to_fp16(buffer.data)
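`_substitute_init_recursively`, added above, replaces the old single-level `Module.__subclasses__()` loop so that `__init__` is also patched for indirect subclasses of `torch.nn.Module`. A self-contained sketch of the traversal it performs (the class hierarchy here is illustrative only):

```python
class Base:
    pass


class A(Base):
    pass


class B(A):  # indirect subclass of Base
    pass


def _substitute_init_recursively(cls, func):
    # Depth-first over the entire subclass tree; cls.__subclasses__()
    # alone only lists the *direct* subclasses.
    for subcls in cls.__subclasses__():
        _substitute_init_recursively(subcls, func)
        func(subcls)


visited = []
_substitute_init_recursively(Base, visited.append)
print(visited)  # [B, A] -- the indirect subclass B is reached as well
```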
+ """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.proj1 = nn.Linear(4, 8) + self.weight = nn.Parameter(torch.randn(8, 8)) + self.proj2 = nn.Linear(8, 4) + + def forward(self, x): + x = self.proj1(x) + x = F.linear(x, self.weight) + x = self.proj2(x) + return x + + +class DummyDataLoader(DummyDataGenerator): + + def generate(self): + data = torch.rand(16, 4) + label = torch.randint(low=0, high=2, size=(16,)) + return data, label + + +@non_distributed_component_funcs.register(name='no_leaf_module') +def get_training_components(): + + def model_builder(checkpoint=True): + return NoLeafModule(checkpoint) + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = torch.nn.CrossEntropyLoss() + return model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/test_zero_data_parallel/test_shard_model_v2.py b/tests/test_zero_data_parallel/test_shard_model_v2.py index adaa98fc9..57109800f 100644 --- a/tests/test_zero_data_parallel/test_shard_model_v2.py +++ b/tests/test_zero_data_parallel/test_shard_model_v2.py @@ -24,7 +24,7 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd @parameterize("enable_autocast", [True]) @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) def run_model_test(enable_autocast, shard_strategy_class): - test_models = ['repeated_computed_layers', 'resnet18', 'bert'] + test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module'] shard_strategy = shard_strategy_class() for model_name in test_models: get_components_func = non_distributed_component_funcs.get_callable(model_name) diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2.py b/tests/test_zero_data_parallel/test_sharded_optim_v2.py index 29df7b667..5cb5ddae6 100644 --- a/tests/test_zero_data_parallel/test_sharded_optim_v2.py +++ b/tests/test_zero_data_parallel/test_sharded_optim_v2.py @@ -45,7 +45,7 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False): @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) @parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio): - test_models = ['repeated_computed_layers', 'resnet18', 'bert'] + test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module'] shard_strategy = shard_strategy_class() if use_cpuadam and cpu_offload is False: