From 1f90a3b1295d11988730374b9bd8fe7f7f463ae9 Mon Sep 17 00:00:00 2001
From: ver217
Date: Tue, 29 Mar 2022 09:09:04 +0800
Subject: [PATCH] [zero] polish ZeroInitContext (#540)

---
 colossalai/zero/init_ctx/init_context.py | 21 +++++++------------
 .../test_init_context.py                 | 14 ++++++-------
 .../test_shard_model_v2.py               |  6 ++----
 .../test_sharded_optim_v2.py             |  4 +---
 .../test_sharded_optim_with_sync_bn.py   |  5 ++---
 .../test_state_dict.py                   |  6 ++----
 .../test_zero_engine.py                  |  5 ++---
 7 files changed, 23 insertions(+), 38 deletions(-)

diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index 00aef1f1e..9d5812d1a 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -4,12 +4,11 @@ from typing import Optional
 import torch
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
+from colossalai.logging import get_dist_logger
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
-from colossalai.logging import get_dist_logger, disable_existing_loggers
 
 
 def _substitute_init_recursively(cls, func):
@@ -107,20 +106,16 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     """
 
     def __init__(self,
-                 convert_fp16: bool,
                  target_device: torch.device,
                  shard_strategy: BaseShardStrategy,
                  shard_param: bool = False,
-                 shard_grad: bool = False,
                  rm_torch_payload_on_the_fly: bool = False,
-                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int),
+                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long),
                  dp_process_group: Optional[ProcessGroup] = None):
         super().__init__()
 
-        self.convert_fp16 = convert_fp16
         self.target_device = target_device
         self.shard_param = shard_param
-        self.shard_grad = shard_grad
         self.shard_strategy = shard_strategy
         self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
         self.initialized_param_list = []
@@ -157,11 +152,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
             target_device = self.target_device
 
-            # convert to fp16 if necessary
-            if self.convert_fp16:
-                param.data = param.data.to(torch.half)
-                if param.grad is not None:
-                    param.grad = param.grad.to(torch.half)
+            # convert to fp16
+            param.data = param.data.to(torch.half)
+            if param.grad is not None:
+                param.grad = param.grad.to(torch.half)
 
             # move torch parameters to the target device
             param.data = param.data.to(target_device)
@@ -179,5 +173,4 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         # We must cast them
         for buffer in module.buffers(recurse=False):
             buffer.data = buffer.data.to(device=torch.cuda.current_device())
-            if self.convert_fp16:
-                buffer.data = cast_tensor_to_fp16(buffer.data)
+            buffer.data = cast_tensor_to_fp16(buffer.data)
diff --git a/tests/test_zero_data_parallel/test_init_context.py b/tests/test_zero_data_parallel/test_init_context.py
index 46eb1dbd0..941110a55 100644
--- a/tests/test_zero_data_parallel/test_init_context.py
+++ b/tests/test_zero_data_parallel/test_init_context.py
@@ -7,16 +7,17 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize
+from colossalai.logging import get_dist_logger
+from colossalai.testing import parameterize, rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
-from colossalai.utils.memory_tracer.model_data_memtracer import col_model_data_mem_usage
-from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.utils.memory_tracer.model_data_memtracer import \
+    col_model_data_mem_usage
 from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
+from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
-from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.logging import get_dist_logger
+
 
 from common import CONFIG
 
@@ -36,8 +37,7 @@ def run_model_test(init_device_type, shard_strategy_class):
             continue
 
         model_numel_tensor = torch.zeros(1, dtype=torch.int)
-        with ZeroInitContext(convert_fp16=True,
-                             target_device=init_device,
+        with ZeroInitContext(target_device=init_device,
                              shard_strategy=shard_strategy_class(),
                              shard_param=True,
                              model_numel_tensor=model_numel_tensor,
diff --git a/tests/test_zero_data_parallel/test_shard_model_v2.py b/tests/test_zero_data_parallel/test_shard_model_v2.py
index 57109800f..fa889bcde 100644
--- a/tests/test_zero_data_parallel/test_shard_model_v2.py
+++ b/tests/test_zero_data_parallel/test_shard_model_v2.py
@@ -7,14 +7,13 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
-from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 
@@ -32,8 +31,7 @@
 
         rm_torch_payload_on_the_fly = False
 
-        with ZeroInitContext(convert_fp16=True,
-                             target_device=torch.cuda.current_device(),
+        with ZeroInitContext(target_device=torch.cuda.current_device(),
                              shard_strategy=shard_strategy,
                              shard_param=True,
                              rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2.py b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
index 76669e94d..816b846fa 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_v2.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
@@ -8,7 +8,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import CPUAdam
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -16,7 +16,6 @@ from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
-from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 
@@ -59,7 +58,6 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
         model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
 
         with ZeroInitContext(
-                convert_fp16=True,
                 target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
                 shard_strategy=shard_strategy,
                 shard_param=True,
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py b/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py
index 55b027e5f..7a52e437a 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py
@@ -10,11 +10,11 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.testing import rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import TensorShardStrategy
 from torchvision.models import resnet50
-from colossalai.testing import rerun_on_exception
 
 
 def run_dist(rank, world_size, port):
@@ -30,8 +30,7 @@
                       port=port,
                       backend='nccl')
 
-    with ZeroInitContext(convert_fp16=True,
-                         target_device=torch.cuda.current_device(),
+    with ZeroInitContext(target_device=torch.cuda.current_device(),
                          shard_strategy=gpc.config.zero.model_config.shard_strategy,
                          shard_param=True):
         model = resnet50()
diff --git a/tests/test_zero_data_parallel/test_state_dict.py b/tests/test_zero_data_parallel/test_state_dict.py
index 15488e0f7..fd1cc2f07 100644
--- a/tests/test_zero_data_parallel/test_state_dict.py
+++ b/tests/test_zero_data_parallel/test_state_dict.py
@@ -8,13 +8,12 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
-from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 
 from common import CONFIG
@@ -28,8 +27,7 @@ def run_zero_state_dict(shard_strategy_class):
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
 
-        with ZeroInitContext(convert_fp16=True,
-                             target_device=torch.cuda.current_device(),
+        with ZeroInitContext(target_device=torch.cuda.current_device(),
                              shard_strategy=shard_strategy,
                              shard_param=True,
                              rm_torch_payload_on_the_fly=False):
diff --git a/tests/test_zero_data_parallel/test_zero_engine.py b/tests/test_zero_data_parallel/test_zero_engine.py
index a5c230777..50153427c 100644
--- a/tests/test_zero_data_parallel/test_zero_engine.py
+++ b/tests/test_zero_data_parallel/test_zero_engine.py
@@ -9,11 +9,11 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.core import global_context as gpc
+from colossalai.testing import rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
-from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 
@@ -32,8 +32,7 @@ def run_dist(rank, world_size, port, parallel_config):
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
-        with ZeroInitContext(convert_fp16=hasattr(gpc.config, 'fp16'),
-                             target_device=torch.cuda.current_device(),
+        with ZeroInitContext(target_device=torch.cuda.current_device(),
                              shard_strategy=gpc.config.zero.model_config.shard_strategy,
                              shard_param=True):
             colo_model = model_builder(checkpoint=True)
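
Usage note (reviewer sketch, not part of the patch): after this change, ZeroInitContext no longer accepts convert_fp16 or shard_grad; parameters and buffers are always cast to fp16 inside the context, and model_numel_tensor now defaults to dtype=torch.long. A minimal sketch of the updated call, mirroring the revised tests, is shown below. It assumes a distributed environment has already been initialized (e.g. via colossalai.launch, as the tests do) and uses a toy nn.Sequential model as a stand-in for the test registry's model builders.

    import torch
    import torch.nn as nn
    from colossalai.zero.init_ctx import ZeroInitContext
    from colossalai.zero.shard_utils import TensorShardStrategy

    # Counter the context fills in as parameters are initialized.
    model_numel_tensor = torch.zeros(1, dtype=torch.long)

    # convert_fp16/shard_grad are gone: parameters (and buffers) are cast to
    # fp16 and parameters are sharded via shard_strategy while the model is built.
    with ZeroInitContext(target_device=torch.cuda.current_device(),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True,
                         rm_torch_payload_on_the_fly=False,
                         model_numel_tensor=model_numel_tensor):
        model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

    print(f'initialized {model_numel_tensor.item()} parameters')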