Revert "[zero] add ZeroTensorShardStrategy (#793)" (#806)

pull/807/head
Jiarui Fang 2022-04-19 14:40:02 +08:00 committed by GitHub
parent 88759e289e
commit e761ad2cd7
12 changed files with 11 additions and 214 deletions

View File

@@ -1,109 +0,0 @@
#include <cuda_runtime.h>
#include <nccl.h>
#include <torch/extension.h>

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

#define CUDACHECK(cmd)                                              \
  do {                                                              \
    cudaError_t e = cmd;                                            \
    if (e != cudaSuccess) {                                         \
      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
             cudaGetErrorString(e));                                \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)

#define NCCLCHECK(cmd)                                              \
  do {                                                              \
    ncclResult_t r = cmd;                                           \
    if (r != ncclSuccess) {                                         \
      printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
             ncclGetErrorString(r));                                \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)

class ZeroCommMgr {
 public:
  cudaStream_t cuda_stream;
  ncclComm_t nccl_comm;

  ZeroCommMgr(const ncclComm_t &comm_) {
    CUDACHECK(cudaStreamCreate(&cuda_stream));
    nccl_comm = comm_;
  }
};

ZeroCommMgr *GMGR = nullptr;

#ifdef USE_C10D_NCCL
#include <c10d/ProcessGroupNCCL.hpp>

class HackNCCLGroup : public c10d::ProcessGroupNCCL {
 public:
  ncclComm_t getcomm(at::Device dev) {
    ncclUniqueId ncclID;
    int rank = getRank();
    if (rank == 0) {
      ncclGetUniqueId(&ncclID);
    }

    broadcastUniqueNCCLID(&ncclID, c10d::OpType::SEND, "colossal_zero_comm",
                          rank);

    ncclComm_t comm;
    NCCLCHECK(ncclCommInitRank(&comm, getSize(), ncclID, rank));
    return comm;
  }
};

int create_zero_comm(c10d::ProcessGroupNCCL &pg, at::Device dev) {
  auto *hack_group = reinterpret_cast<HackNCCLGroup *>(&pg);
  GMGR = new ZeroCommMgr(hack_group->getcomm(dev));
  assert(GMGR->nccl_comm != 0);
  return 0;
}
#endif

template <typename scalar_t>
void colo_all_gather_impl(scalar_t *recvbuff, int rank, int sendcount,
                          ncclDataType_t data_type) {
  scalar_t *sendbuff = recvbuff + (rank * sendcount);
  NCCLCHECK(ncclAllGather(sendbuff, recvbuff, sendcount, data_type,
                          GMGR->nccl_comm, GMGR->cuda_stream));
  CUDACHECK(cudaStreamSynchronize(GMGR->cuda_stream));
}

int colo_all_gather(torch::Tensor &input_tensor, int rank, int world_size) {
  CHECK_INPUT(input_tensor);

  auto total_size = input_tensor.numel();
  assert(total_size % world_size == 0);
  auto sendcount = total_size / world_size;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input_tensor.scalar_type(), "colo_all_gather", ([&] {
        colo_all_gather_impl<scalar_t>(
            input_tensor.data_ptr<scalar_t>(), rank, sendcount,
            input_tensor.scalar_type() == at::ScalarType::Half ? ncclHalf
                                                               : ncclFloat);
      }));

  return 0;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef USE_C10D_NCCL
  m.def("create_comm", &create_zero_comm,
        "Create the communication environment for Colossal Zero");
#endif
  m.def("inplace_all_gather", &colo_all_gather,
        "All gather operation used in Colossal Zero");
}
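For reference, a minimal usage sketch of the two exported functions, separate from the reverted files and written under these assumptions: the extension is built as `colossal_zero_comm`, the process group is NCCL-backed, and `create_comm` has already been called once per process. The helper name `gather_flat_tensor` is hypothetical; it only illustrates the in-place contract of `inplace_all_gather`, whose C++ implementation derives the send buffer from `recvbuff + rank * sendcount`, so the caller's shard must already live at this rank's offset inside the full-size tensor.

# Hypothetical sketch only; names other than the extension's two functions are illustrative.
import torch
import torch.distributed as dist

import colossal_zero_comm    # importable only if the extension above was built

def gather_flat_tensor(full_buffer: torch.Tensor, shard: torch.Tensor, pg: dist.ProcessGroup) -> torch.Tensor:
    # One-time setup elsewhere: colossal_zero_comm.create_comm(pg, torch.device('cuda', torch.cuda.current_device()))
    rank = dist.get_rank(pg)
    world_size = dist.get_world_size(pg)
    shard_numel = full_buffer.numel() // world_size
    # The C++ side reads this rank's send buffer at offset rank * sendcount, so the
    # local shard must already occupy this rank's slice of the full-size buffer.
    full_buffer.narrow(0, rank * shard_numel, shard_numel).copy_(shard.view(-1))
    colossal_zero_comm.inplace_all_gather(full_buffer, rank, world_size)
    return full_buffer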

View File

@@ -1 +0,0 @@
from .zero_comm import ZeroDist

View File

@@ -1,46 +0,0 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.utils import get_current_device
from typing import Optional

ZERO_USE_NCCL = False
try:
    import colossal_zero_comm
    ZERO_USE_NCCL = True
except ImportError:
    print("Please reinstall Colossalai via pip.")


class ZeroCommWorld(metaclass=SingletonMeta):
    """Zero communicator, used for communications in zero parallel.
    """

    def __init__(self):
        super().__init__()
        self.zero_pg: Optional[ProcessGroup] = None

    @property
    def is_initialized(self):
        return self.zero_pg is not None

    def zero_comm_init(self, comm_group: ProcessGroup):
        if not ZERO_USE_NCCL:
            return

        if self.is_initialized:
            assert self.zero_pg == comm_group, "Cannot initialize zero group twice"
            return

        self.zero_pg = comm_group
        colossal_zero_comm.create_comm(self.zero_pg, get_current_device())

    def zero_all_gather(self, input_tensor: torch.Tensor):
        assert self.zero_pg is not None, "Please initialize zero communication world first"
        rank = dist.get_rank(self.zero_pg)
        world_size = self.zero_pg.size()
        colossal_zero_comm.inplace_all_gather(input_tensor, rank, world_size)


ZeroDist = ZeroCommWorld()

View File

@@ -12,7 +12,6 @@ from colossalai.logging import get_dist_logger
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_param import ShardedParamV2
from colossalai.zero.comm import ZeroDist
from contextlib import AbstractContextManager
@@ -192,7 +191,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        The Callback function when entering the context
        """
        self.logger = get_dist_logger("ZeroInitContext")
        ZeroDist.zero_comm_init(self.dp_process_group)    # initialize zero communication world
        # substitute fan-in and fan-out calculation
        self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out

View File

@@ -1,6 +1,5 @@
from .base_shard_strategy import BaseShardStrategy
from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
from .tensor_shard_strategy import TensorShardStrategy
from .zero_tensor_shard_strategy import ZeroTensorShardStrategy
__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy', 'ZeroTensorShardStrategy']
__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']

View File

@@ -1,38 +0,0 @@
from typing import Optional

import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.zero.comm import ZeroDist

from .tensor_shard_strategy import TensorShardStrategy


class ZeroTensorShardStrategy(TensorShardStrategy):
    """Uses the same shard scheme as `TensorShardStrategy`, but its all-gather operation is
    in-place, so no extra buffer is created (an extra buffer is created when using
    `torch.distributed.all_gather`). This can reduce peak memory usage in zero-offload.
    Note that this strategy is tightly coupled with zero: you cannot change its
    communication group and must use ZeroInitContext to create your model.
    """

    def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
        if not t.is_sharded:
            return
        target_device = t.device
        payload_numel = t.payload.numel()
        world_size = dist.get_world_size(process_group)
        rank = dist.get_rank(process_group)

        buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device())
        buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
        buffer_list[rank].copy_(t.payload)
        ZeroDist.zero_all_gather(buffer)    # note: process_group is not used here

        gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
        t.reset_payload(gathered_payload)
        colo_model_data_tensor_move_inline(t, target_device)
        t.is_sharded = False
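The docstring above contrasts this in-place gather with the buffered path used when gathering via `torch.distributed.all_gather` (roughly what `TensorShardStrategy` does, and what the tests below fall back to after this revert). A minimal sketch of that buffered path follows; the helper name `buffered_gather` is hypothetical and the payload is assumed to be a flat CUDA tensor. With a list of per-rank chunk views, the NCCL-backed `all_gather` may stage the result through an extra flattened temporary, which is the peak-memory overhead the removed strategy aimed to avoid.

# Hypothetical sketch of the buffered all-gather path, for comparison only.
import torch
import torch.distributed as dist

def buffered_gather(payload: torch.Tensor, world_size: int, group=None) -> torch.Tensor:
    # Allocate the full-size output and view it as one chunk per rank.
    buffer = torch.empty(payload.numel() * world_size, dtype=payload.dtype, device=payload.device)
    chunks = list(torch.chunk(buffer, chunks=world_size, dim=0))
    # all_gather fills every chunk, including this rank's own slice.
    dist.all_gather(chunks, payload.contiguous(), group=group)
    return buffer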

View File

@@ -134,12 +134,6 @@ if build_cuda_ext:
            'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + version_dependent_macros + extra_cuda_flags)
        })

    ext_modules.append(
        cuda_ext_helper(name='colossal_zero_comm',
                        sources=['zero_comm.cpp'],
                        extra_cuda_flags=['-DUSE_C10D_NCCL'],
                        extra_cxx_flags=['-DUSE_C10D_NCCL']))

    ext_modules.append(
        cuda_ext_helper('colossal_C', [
            'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu',

View File

@@ -9,7 +9,7 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import ZeroTensorShardStrategy
from colossalai.zero.shard_utils import BucketTensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
@@ -20,7 +20,7 @@ from common import CONFIG
@parameterize("cpu_offload", [True, False])
@parameterize("shard_strategy_class", [ZeroTensorShardStrategy])
@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
@parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio):
    test_models = ['repeated_computed_layers']

View File

@@ -15,14 +15,14 @@ from colossalai.gemini.memory_tracer.model_data_memtracer import \
    colo_model_mem_usage
from colossalai.utils.memory import colo_device_memory_used
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy, ZeroTensorShardStrategy)
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from tests.components_to_test.registry import non_distributed_component_funcs
from common import CONFIG
@parameterize("init_device_type", ['cpu', 'cuda'])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy, ZeroTensorShardStrategy])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_model_test(init_device_type, shard_strategy_class):
    logger = get_dist_logger("test_zero_init")

View File

@@ -8,7 +8,7 @@ from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.shard_utils import ZeroTensorShardStrategy
from colossalai.zero.shard_utils import BucketTensorShardStrategy
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from functools import partial
@@ -35,7 +35,7 @@ def run_mem_collector_testing():
    fraction = (50 * 1024**2) / cuda_capacity
    # limit max memory to 50MB
    colo_set_process_memory_fraction(fraction)
    shard_strategy = ZeroTensorShardStrategy()
    shard_strategy = BucketTensorShardStrategy()
    with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
        model = MyTestModel()

View File

@@ -10,7 +10,7 @@ import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, ZeroTensorShardStrategy)
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.utils import col_model_deepcopy
@@ -21,7 +21,7 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd
@parameterize("enable_autocast", [True])
@parameterize("shard_strategy_class", [ZeroTensorShardStrategy, BucketTensorShardStrategy])
@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
def run_model_test(enable_autocast, shard_strategy_class):
    test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module']
    shard_strategy = shard_strategy_class()

View File

@@ -11,7 +11,7 @@ import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy, ZeroTensorShardStrategy)
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from tests.components_to_test.registry import non_distributed_component_funcs
@@ -19,7 +19,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
from common import CONFIG
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy, ZeroTensorShardStrategy])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_zero_state_dict(shard_strategy_class):
    test_models = ['repeated_computed_layers', 'resnet18']
    shard_strategy = shard_strategy_class()