From 88759e289efd0a7b5e0d7bf8e01dbe29db85cf71 Mon Sep 17 00:00:00 2001
From: HELSON
Date: Tue, 19 Apr 2022 14:32:45 +0800
Subject: [PATCH] [zero] add ZeroTensorShardStrategy (#793)

---
 .../kernel/cuda_native/csrc/zero_comm.cpp      | 109 ++++++++++++++++++
 colossalai/zero/comm/__init__.py               |   1 +
 colossalai/zero/comm/zero_comm.py              |  46 ++++++
 colossalai/zero/init_ctx/init_context.py       |   2 +
 colossalai/zero/shard_utils/__init__.py        |   3 +-
 .../shard_utils/zero_tensor_shard_strategy.py  |  38 ++++++
 setup.py                                       |   6 +
 tests/test_zero/test_found_inf.py              |   4 +-
 tests/test_zero/test_init_context.py           |   4 +-
 tests/test_zero/test_mem_collector.py          |   4 +-
 tests/test_zero/test_shard_model_v2.py         |   4 +-
 tests/test_zero/test_state_dict.py             |   4 +-
 12 files changed, 214 insertions(+), 11 deletions(-)
 create mode 100644 colossalai/kernel/cuda_native/csrc/zero_comm.cpp
 create mode 100644 colossalai/zero/comm/__init__.py
 create mode 100644 colossalai/zero/comm/zero_comm.py
 create mode 100644 colossalai/zero/shard_utils/zero_tensor_shard_strategy.py

diff --git a/colossalai/kernel/cuda_native/csrc/zero_comm.cpp b/colossalai/kernel/cuda_native/csrc/zero_comm.cpp
new file mode 100644
index 000000000..e07d6f504
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/zero_comm.cpp
@@ -0,0 +1,109 @@
+#include <cuda_runtime.h>
+#include <nccl.h>
+#include <torch/extension.h>
+
+#define CHECK_CUDA(x) \
+  TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) \
+  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) \
+  CHECK_CUDA(x);       \
+  CHECK_CONTIGUOUS(x)
+
+#define CUDACHECK(cmd)                                              \
+  do {                                                              \
+    cudaError_t e = cmd;                                            \
+    if (e != cudaSuccess) {                                         \
+      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
+             cudaGetErrorString(e));                                \
+      exit(EXIT_FAILURE);                                           \
+    }                                                               \
+  } while (0)
+
+#define NCCLCHECK(cmd)                                              \
+  do {                                                              \
+    ncclResult_t r = cmd;                                           \
+    if (r != ncclSuccess) {                                         \
+      printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
+             ncclGetErrorString(r));                                \
+      exit(EXIT_FAILURE);                                           \
+    }                                                               \
+  } while (0)
+
+class ZeroCommMgr {
+ public:
+  cudaStream_t cuda_stream;
+  ncclComm_t nccl_comm;
+
+  ZeroCommMgr(const ncclComm_t &comm_) {
+    CUDACHECK(cudaStreamCreate(&cuda_stream));
+    nccl_comm = comm_;
+  }
+};
+
+ZeroCommMgr *GMGR = nullptr;
+
+#ifdef USE_C10D_NCCL
+#include <c10d/ProcessGroupNCCL.hpp>
+
+class HackNCCLGroup : public c10d::ProcessGroupNCCL {
+ public:
+  ncclComm_t getcomm(at::Device dev) {
+    ncclUniqueId ncclID;
+    int rank = getRank();
+    if (rank == 0) {
+      ncclGetUniqueId(&ncclID);
+    }
+
+    broadcastUniqueNCCLID(&ncclID, c10d::OpType::SEND, "colossal_zero_comm",
+                          rank);
+
+    ncclComm_t comm;
+    NCCLCHECK(ncclCommInitRank(&comm, getSize(), ncclID, rank));
+    return comm;
+  }
+};
+
+int create_zero_comm(c10d::ProcessGroupNCCL &pg, at::Device dev) {
+  auto *hack_group = reinterpret_cast<HackNCCLGroup *>(&pg);
+  GMGR = new ZeroCommMgr(hack_group->getcomm(dev));
+  assert(GMGR->nccl_comm != 0);
+  return 0;
+}
+#endif
+
+template <typename scalar_t>
+void colo_all_gather_impl(scalar_t *recvbuff, int rank, int sendcount,
+                          ncclDataType_t data_type) {
+  scalar_t *sendbuff = recvbuff + (rank * sendcount);
+  NCCLCHECK(ncclAllGather(sendbuff, recvbuff, sendcount, data_type,
+                          GMGR->nccl_comm, GMGR->cuda_stream));
+  CUDACHECK(cudaStreamSynchronize(GMGR->cuda_stream));
+}
+
+int colo_all_gather(torch::Tensor &input_tensor, int rank, int world_size) {
+  CHECK_INPUT(input_tensor);
+
+  auto total_size = input_tensor.numel();
+  assert(total_size % world_size == 0);
+  auto sendcount = total_size / world_size;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      input_tensor.scalar_type(), "colo_all_gather", ([&] {
+        colo_all_gather_impl<scalar_t>(
+            input_tensor.data_ptr<scalar_t>(), rank, sendcount,
+            input_tensor.scalar_type() == at::ScalarType::Half ? ncclHalf
+                                                               : ncclFloat);
+      }));
+
+  return 0;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+#ifdef USE_C10D_NCCL
+  m.def("create_comm", &create_zero_comm,
+        "Create the communication environment for Colossal Zero");
+#endif
+  m.def("inplace_all_gather", &colo_all_gather,
+        "All gather operation used in Colossal Zero");
+}
diff --git a/colossalai/zero/comm/__init__.py b/colossalai/zero/comm/__init__.py
new file mode 100644
index 000000000..16b2d3e02
--- /dev/null
+++ b/colossalai/zero/comm/__init__.py
@@ -0,0 +1 @@
+from .zero_comm import ZeroDist
diff --git a/colossalai/zero/comm/zero_comm.py b/colossalai/zero/comm/zero_comm.py
new file mode 100644
index 000000000..a2d54a015
--- /dev/null
+++ b/colossalai/zero/comm/zero_comm.py
@@ -0,0 +1,46 @@
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.utils import get_current_device
+from typing import Optional
+
+ZERO_USE_NCCL = False
+try:
+    import colossal_zero_comm
+    ZERO_USE_NCCL = True
+except ImportError:
+    print("The colossal_zero_comm extension is not built. Please reinstall Colossal-AI.")
+
+
+class ZeroCommWorld(metaclass=SingletonMeta):
+    """Zero communicator, used for communications in zero parallel.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.zero_pg: Optional[ProcessGroup] = None
+
+    @property
+    def is_initialized(self):
+        return self.zero_pg is not None
+
+    def zero_comm_init(self, comm_group: ProcessGroup):
+        if not ZERO_USE_NCCL:
+            return
+
+        if self.is_initialized:
+            assert self.zero_pg == comm_group, "Cannot initialize zero group twice"
+            return
+
+        self.zero_pg = comm_group
+        colossal_zero_comm.create_comm(self.zero_pg, get_current_device())
+
+    def zero_all_gather(self, input_tensor: torch.Tensor):
+        assert self.zero_pg is not None, "Please initialize zero communication world first"
+        rank = dist.get_rank(self.zero_pg)
+        world_size = self.zero_pg.size()
+        colossal_zero_comm.inplace_all_gather(input_tensor, rank, world_size)
+
+
+ZeroDist = ZeroCommWorld()
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index c27d7a577..7a5bf6f4c 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -12,6 +12,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
+from colossalai.zero.comm import ZeroDist
 from contextlib import AbstractContextManager
 
 
@@ -191,6 +192,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         The Callback function when entering the context
         """
         self.logger = get_dist_logger("ZeroInitContext")
+        ZeroDist.zero_comm_init(self.dp_process_group)    # initialize zero communication world
 
         # substitute fan-in and fan-out calculation
         self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out
diff --git a/colossalai/zero/shard_utils/__init__.py b/colossalai/zero/shard_utils/__init__.py
index 5e5d63a7e..6a0e0c59b 100644
--- a/colossalai/zero/shard_utils/__init__.py
+++ b/colossalai/zero/shard_utils/__init__.py
@@ -1,5 +1,6 @@
 from .base_shard_strategy import BaseShardStrategy
 from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
 from .tensor_shard_strategy import TensorShardStrategy
+from .zero_tensor_shard_strategy import ZeroTensorShardStrategy
 
-__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
+__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy', 'ZeroTensorShardStrategy']
diff --git a/colossalai/zero/shard_utils/zero_tensor_shard_strategy.py b/colossalai/zero/shard_utils/zero_tensor_shard_strategy.py
new file mode 100644
index 000000000..afd2f619c
--- /dev/null
+++ b/colossalai/zero/shard_utils/zero_tensor_shard_strategy.py
@@ -0,0 +1,38 @@
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from colossalai.utils import get_current_device
+from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
+from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.zero.comm import ZeroDist
+
+from .tensor_shard_strategy import TensorShardStrategy
+
+
+class ZeroTensorShardStrategy(TensorShardStrategy):
+    """Use the same shard scheme as `TensorShardStrategy`.
+    Unlike `torch.distributed.all_gather`, its all-gather operation is in-place,
+    so no extra communication buffer is allocated, which reduces peak memory
+    usage in zero-offload. Note that this strategy is tightly coupled with zero:
+    its communication group cannot be changed, and the model must be created
+    inside `ZeroInitContext`.
+    """
+
+    def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
+        if not t.is_sharded:
+            return
+        target_device = t.device
+        payload_numel = t.payload.numel()
+        world_size = dist.get_world_size(process_group)
+        rank = dist.get_rank(process_group)
+
+        buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device())
+        buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
+        buffer_list[rank].copy_(t.payload)
+
+        ZeroDist.zero_all_gather(buffer)    # note: process_group is ignored here; ZeroDist uses its own group
+        gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
+        t.reset_payload(gathered_payload)
+        colo_model_data_tensor_move_inline(t, target_device)
+        t.is_sharded = False
diff --git a/setup.py b/setup.py
index 12b12c31d..c8817c82a 100644
--- a/setup.py
+++ b/setup.py
@@ -134,6 +134,12 @@ if build_cuda_ext:
             'nvcc':
                 append_nvcc_threads(['-O3', '--use_fast_math'] + version_dependent_macros + extra_cuda_flags)
         })
+    ext_modules.append(
+        cuda_ext_helper(name='colossal_zero_comm',
+                        sources=['zero_comm.cpp'],
+                        extra_cuda_flags=['-DUSE_C10D_NCCL'],
+                        extra_cxx_flags=['-DUSE_C10D_NCCL']))
+
     ext_modules.append(
         cuda_ext_helper('colossal_C', [
             'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu',
diff --git a/tests/test_zero/test_found_inf.py b/tests/test_zero/test_found_inf.py
index 34283f501..af1b2e670 100644
--- a/tests/test_zero/test_found_inf.py
+++ b/tests/test_zero/test_found_inf.py
@@ -9,7 +9,7 @@ from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils import BucketTensorShardStrategy
+from colossalai.zero.shard_utils import ZeroTensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
@@ -20,7 +20,7 @@ from common import CONFIG
 
 
 @parameterize("cpu_offload", [True, False])
-@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
+@parameterize("shard_strategy_class", [ZeroTensorShardStrategy])
 @parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
 def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio):
     test_models = ['repeated_computed_layers']
diff --git a/tests/test_zero/test_init_context.py b/tests/test_zero/test_init_context.py
index b955e4852..61bec973c 100644
--- a/tests/test_zero/test_init_context.py
+++ b/tests/test_zero/test_init_context.py
@@ -15,14 +15,14 @@ from colossalai.gemini.memory_tracer.model_data_memtracer import \
     colo_model_mem_usage
 from colossalai.utils.memory import colo_device_memory_used
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy, ZeroTensorShardStrategy)
 from tests.components_to_test.registry import non_distributed_component_funcs
 
 from common import CONFIG
 
 
 @parameterize("init_device_type", ['cpu', 'cuda'])
-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
+@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy, ZeroTensorShardStrategy])
 def run_model_test(init_device_type, shard_strategy_class):
     logger = get_dist_logger("test_zero_init")
diff --git a/tests/test_zero/test_mem_collector.py b/tests/test_zero/test_mem_collector.py
index bea971935..d311e0f37 100644
--- a/tests/test_zero/test_mem_collector.py
+++ b/tests/test_zero/test_mem_collector.py
@@ -8,7 +8,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.sharded_model import ShardedModelV2
-from colossalai.zero.shard_utils import BucketTensorShardStrategy
+from colossalai.zero.shard_utils import ZeroTensorShardStrategy
 from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
 from functools import partial
@@ -35,7 +35,7 @@ def run_mem_collector_testing():
     fraction = (50 * 1024**2) / cuda_capacity    # limit max memory to 50MB
     colo_set_process_memory_fraction(fraction)
 
-    shard_strategy = BucketTensorShardStrategy()
+    shard_strategy = ZeroTensorShardStrategy()
     with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
         model = MyTestModel()
diff --git a/tests/test_zero/test_shard_model_v2.py b/tests/test_zero/test_shard_model_v2.py
index 654c82a46..62b860c23 100644
--- a/tests/test_zero/test_shard_model_v2.py
+++ b/tests/test_zero/test_shard_model_v2.py
@@ -10,7 +10,7 @@ import torch.multiprocessing as mp
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, ZeroTensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
@@ -21,7 +21,7 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd
 
 
 @parameterize("enable_autocast", [True])
-@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
+@parameterize("shard_strategy_class", [ZeroTensorShardStrategy, BucketTensorShardStrategy])
 def run_model_test(enable_autocast, shard_strategy_class):
     test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module']
     shard_strategy = shard_strategy_class()
diff --git a/tests/test_zero/test_state_dict.py b/tests/test_zero/test_state_dict.py
index 188bc5968..18d4c983a 100644
--- a/tests/test_zero/test_state_dict.py
+++ b/tests/test_zero/test_state_dict.py
@@ -11,7 +11,7 @@ import torch.multiprocessing as mp
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy, ZeroTensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -19,7 +19,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 
 from common import CONFIG
 
 
-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
+@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy, ZeroTensorShardStrategy])
 def run_zero_state_dict(shard_strategy_class):
     test_models = ['repeated_computed_layers', 'resnet18']
     shard_strategy = shard_strategy_class()
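
Usage sketch (for review; not part of the diff). The tests above show the intended entry point: construct the model inside ZeroInitContext with a ZeroTensorShardStrategy, as test_mem_collector.py does after this change. The wrapping in ShardedModelV2 follows the existing test-suite pattern rather than a hunk visible in this diff, and the toy torch.nn.Linear module is hypothetical; the sketch also assumes colossalai.launch(...) has already initialized the distributed environment and that the colossal_zero_comm extension was built by setup.py.

    import torch
    from colossalai.utils import get_current_device
    from colossalai.zero.init_ctx import ZeroInitContext
    from colossalai.zero.shard_utils import ZeroTensorShardStrategy
    from colossalai.zero.sharded_model import ShardedModelV2

    # The strategy only works when the model is created inside ZeroInitContext,
    # because entering the context initializes ZeroDist with the data-parallel group.
    shard_strategy = ZeroTensorShardStrategy()
    with ZeroInitContext(target_device=get_current_device(),
                         shard_strategy=shard_strategy,
                         shard_param=True):
        model = torch.nn.Linear(128, 10)    # placeholder module for illustration
    sharded_model = ShardedModelV2(model, shard_strategy)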
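
The memory saving described in the ZeroTensorShardStrategy docstring comes from _gather_tensor: one flat buffer sized for the full parameter is allocated, the local shard is copied into its own chunk, and ZeroDist.zero_all_gather fills the remaining chunks in place through the colossal_zero_comm extension (the NCCL send buffer is a view into the receive buffer), so no separate per-rank tensor list is needed. Below is a minimal standalone sketch of that call path; the shard size and dtype are hypothetical, and it assumes torch.distributed runs with the NCCL backend and that ZeroDist.zero_comm_init has already been called (normally done when entering ZeroInitContext).

    import torch
    import torch.distributed as dist
    from colossalai.zero.comm import ZeroDist

    def gather_shard_inplace(shard: torch.Tensor) -> torch.Tensor:
        # Mirrors ZeroTensorShardStrategy._gather_tensor: a single flat CUDA
        # buffer holds the whole parameter; this rank writes its shard into
        # its own slot, then the in-place all-gather fills the other slots.
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        buffer = torch.empty(shard.numel() * world_size, dtype=shard.dtype, device=shard.device)
        torch.chunk(buffer, world_size, dim=0)[rank].copy_(shard)
        ZeroDist.zero_all_gather(buffer)
        return buffer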