mirror of https://github.com/hpcaitech/ColossalAI
parent
88759e289e
commit
e761ad2cd7
|
@ -1,109 +0,0 @@
|
|||
#include <cuda_runtime.h>
|
||||
#include <nccl.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#define CHECK_CUDA(x) \
|
||||
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
#define CUDACHECK(cmd) \
|
||||
do { \
|
||||
cudaError_t e = cmd; \
|
||||
if (e != cudaSuccess) { \
|
||||
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
|
||||
cudaGetErrorString(e)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define NCCLCHECK(cmd) \
|
||||
do { \
|
||||
ncclResult_t r = cmd; \
|
||||
if (r != ncclSuccess) { \
|
||||
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
|
||||
ncclGetErrorString(r)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
class ZeroCommMgr {
|
||||
public:
|
||||
cudaStream_t cuda_stream;
|
||||
ncclComm_t nccl_comm;
|
||||
|
||||
ZeroCommMgr(const ncclComm_t &comm_) {
|
||||
CUDACHECK(cudaStreamCreate(&cuda_stream));
|
||||
nccl_comm = comm_;
|
||||
}
|
||||
};
|
||||
|
||||
ZeroCommMgr *GMGR = nullptr;
|
||||
|
||||
#ifdef USE_C10D_NCCL
|
||||
#include <c10d/ProcessGroupNCCL.hpp>
|
||||
|
||||
class HackNCCLGroup : public c10d::ProcessGroupNCCL {
|
||||
public:
|
||||
ncclComm_t getcomm(at::Device dev) {
|
||||
ncclUniqueId ncclID;
|
||||
int rank = getRank();
|
||||
if (rank == 0) {
|
||||
ncclGetUniqueId(&ncclID);
|
||||
}
|
||||
|
||||
broadcastUniqueNCCLID(&ncclID, c10d::OpType::SEND, "colossal_zero_comm",
|
||||
rank);
|
||||
|
||||
ncclComm_t comm;
|
||||
NCCLCHECK(ncclCommInitRank(&comm, getSize(), ncclID, rank));
|
||||
return comm;
|
||||
}
|
||||
};
|
||||
|
||||
int create_zero_comm(c10d::ProcessGroupNCCL &pg, at::Device dev) {
|
||||
auto *hack_group = reinterpret_cast<HackNCCLGroup *>(&pg);
|
||||
GMGR = new ZeroCommMgr(hack_group->getcomm(dev));
|
||||
assert(GMGR->nccl_comm != 0);
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename scalar_t>
|
||||
void colo_all_gather_impl(scalar_t *recvbuff, int rank, int sendcount,
|
||||
ncclDataType_t data_type) {
|
||||
scalar_t *sendbuff = recvbuff + (rank * sendcount);
|
||||
NCCLCHECK(ncclAllGather(sendbuff, recvbuff, sendcount, data_type,
|
||||
GMGR->nccl_comm, GMGR->cuda_stream));
|
||||
CUDACHECK(cudaStreamSynchronize(GMGR->cuda_stream));
|
||||
}
|
||||
|
||||
int colo_all_gather(torch::Tensor &input_tensor, int rank, int world_size) {
|
||||
CHECK_INPUT(input_tensor);
|
||||
|
||||
auto total_size = input_tensor.numel();
|
||||
assert(total_size % world_size == 0);
|
||||
auto sendcount = total_size / world_size;
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
input_tensor.scalar_type(), "colo_all_gather", ([&] {
|
||||
colo_all_gather_impl<scalar_t>(
|
||||
input_tensor.data_ptr<scalar_t>(), rank, sendcount,
|
||||
input_tensor.scalar_type() == at::ScalarType::Half ? ncclHalf
|
||||
: ncclFloat);
|
||||
}));
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
#ifdef USE_C10D_NCCL
|
||||
m.def("create_comm", &create_zero_comm,
|
||||
"Create the communication environment for Colossal Zero");
|
||||
#endif
|
||||
m.def("inplace_all_gather", &colo_all_gather,
|
||||
"All gather operation used in Colossal Zero");
|
||||
}
|
|
@ -1 +0,0 @@
|
|||
from .zero_comm import ZeroDist
|
|
@ -1,46 +0,0 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
from colossalai.utils import get_current_device
|
||||
from typing import Optional
|
||||
|
||||
ZERO_USE_NCCL = False
|
||||
try:
|
||||
import colossal_zero_comm
|
||||
ZERO_USE_NCCL = True
|
||||
except ImportError:
|
||||
print("Please pip reinstall Colossalai.")
|
||||
|
||||
|
||||
class ZeroCommWorld(metaclass=SingletonMeta):
|
||||
"""Zero communicator, used for communications in zero parallel.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.zero_pg: Optional[ProcessGroup] = None
|
||||
|
||||
@property
|
||||
def is_initialized(self):
|
||||
return self.zero_pg is not None
|
||||
|
||||
def zero_comm_init(self, comm_group: ProcessGroup):
|
||||
if not ZERO_USE_NCCL:
|
||||
return
|
||||
|
||||
if self.is_initialized:
|
||||
assert self.zero_pg == comm_group, "Cant not initialize zero group twice"
|
||||
return
|
||||
|
||||
self.zero_pg = comm_group
|
||||
colossal_zero_comm.create_comm(self.zero_pg, get_current_device())
|
||||
|
||||
def zero_all_gather(self, input_tensor: torch.Tensor):
|
||||
assert self.zero_pg is not None, "Please initialize zero communication world first"
|
||||
rank = dist.get_rank(self.zero_pg)
|
||||
world_size = self.zero_pg.size()
|
||||
colossal_zero_comm.inplace_all_gather(input_tensor, rank, world_size)
|
||||
|
||||
|
||||
ZeroDist = ZeroCommWorld()
|
|
@ -12,7 +12,6 @@ from colossalai.logging import get_dist_logger
|
|||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
|
||||
from colossalai.zero.sharded_param import ShardedParamV2
|
||||
from colossalai.zero.comm import ZeroDist
|
||||
from contextlib import AbstractContextManager
|
||||
|
||||
|
||||
|
@ -192,7 +191,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
The Callback function when entering the context
|
||||
"""
|
||||
self.logger = get_dist_logger("ZeroInitContext")
|
||||
ZeroDist.zero_comm_init(self.dp_process_group) # initialize zero communication world
|
||||
|
||||
# substitute fan-in and fan-out calculation
|
||||
self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from .base_shard_strategy import BaseShardStrategy
|
||||
from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
|
||||
from .tensor_shard_strategy import TensorShardStrategy
|
||||
from .zero_tensor_shard_strategy import ZeroTensorShardStrategy
|
||||
|
||||
__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy', 'ZeroTensorShardStrategy']
|
||||
__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
|
||||
|
|
|
@ -1,38 +0,0 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
from colossalai.zero.comm import ZeroDist
|
||||
|
||||
from .tensor_shard_strategy import TensorShardStrategy
|
||||
|
||||
|
||||
class ZeroTensorShardStrategy(TensorShardStrategy):
|
||||
"""Use the same shard scheme as `TensorShardStrategy`'s.
|
||||
But its all-gather operation is in-place, meaning that no extra buffer is created.
|
||||
Extra buffer is created when using `torch.distributed.all_gather`.
|
||||
This can reduce peak memory used in zero-offload.
|
||||
You should notice that this strategy is highly coupled with zero.
|
||||
You can not change its communication group and must use ZeroContext to create your model.
|
||||
"""
|
||||
|
||||
def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
|
||||
if not t.is_sharded:
|
||||
return
|
||||
target_device = t.device
|
||||
payload_numel = t.payload.numel()
|
||||
world_size = dist.get_world_size(process_group)
|
||||
rank = dist.get_rank(process_group)
|
||||
|
||||
buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device())
|
||||
buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
|
||||
buffer_list[rank].copy_(t.payload)
|
||||
|
||||
ZeroDist.zero_all_gather(buffer) # notice: process_group is useless here
|
||||
gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
|
||||
t.reset_payload(gathered_payload)
|
||||
colo_model_data_tensor_move_inline(t, target_device)
|
||||
t.is_sharded = False
|
6
setup.py
6
setup.py
|
@ -134,12 +134,6 @@ if build_cuda_ext:
|
|||
'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + version_dependent_macros + extra_cuda_flags)
|
||||
})
|
||||
|
||||
ext_modules.append(
|
||||
cuda_ext_helper(name='colossal_zero_comm',
|
||||
sources=['zero_comm.cpp'],
|
||||
extra_cuda_flags=['-DUSE_C10D_NCCL'],
|
||||
extra_cxx_flags=['-DUSE_C10D_NCCL']))
|
||||
|
||||
ext_modules.append(
|
||||
cuda_ext_helper('colossal_C', [
|
||||
'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu',
|
||||
|
|
|
@ -9,7 +9,7 @@ from colossalai.nn.optimizer import HybridAdam
|
|||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils import ZeroTensorShardStrategy
|
||||
from colossalai.zero.shard_utils import BucketTensorShardStrategy
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.sharded_optim import ShardedOptimizerV2
|
||||
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
|
||||
|
@ -20,7 +20,7 @@ from common import CONFIG
|
|||
|
||||
|
||||
@parameterize("cpu_offload", [True, False])
|
||||
@parameterize("shard_strategy_class", [ZeroTensorShardStrategy])
|
||||
@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
|
||||
@parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
|
||||
def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio):
|
||||
test_models = ['repeated_computed_layers']
|
||||
|
|
|
@ -15,14 +15,14 @@ from colossalai.gemini.memory_tracer.model_data_memtracer import \
|
|||
colo_model_mem_usage
|
||||
from colossalai.utils.memory import colo_device_memory_used
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy, ZeroTensorShardStrategy)
|
||||
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
from common import CONFIG
|
||||
|
||||
|
||||
@parameterize("init_device_type", ['cpu', 'cuda'])
|
||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy, ZeroTensorShardStrategy])
|
||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def run_model_test(init_device_type, shard_strategy_class):
|
||||
logger = get_dist_logger("test_zero_init")
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from colossalai.utils.cuda import get_current_device
|
|||
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.shard_utils import ZeroTensorShardStrategy
|
||||
from colossalai.zero.shard_utils import BucketTensorShardStrategy
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from functools import partial
|
||||
|
@ -35,7 +35,7 @@ def run_mem_collector_testing():
|
|||
fraction = (50 * 1024**2) / cuda_capacity
|
||||
# limit max memory to 50MB
|
||||
colo_set_process_memory_fraction(fraction)
|
||||
shard_strategy = ZeroTensorShardStrategy()
|
||||
shard_strategy = BucketTensorShardStrategy()
|
||||
with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
|
||||
model = MyTestModel()
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ import torch.multiprocessing as mp
|
|||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, ZeroTensorShardStrategy)
|
||||
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
|
||||
from colossalai.zero.sharded_model.utils import col_model_deepcopy
|
||||
|
@ -21,7 +21,7 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd
|
|||
|
||||
|
||||
@parameterize("enable_autocast", [True])
|
||||
@parameterize("shard_strategy_class", [ZeroTensorShardStrategy, BucketTensorShardStrategy])
|
||||
@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
|
||||
def run_model_test(enable_autocast, shard_strategy_class):
|
||||
test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module']
|
||||
shard_strategy = shard_strategy_class()
|
||||
|
|
|
@ -11,7 +11,7 @@ import torch.multiprocessing as mp
|
|||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy, ZeroTensorShardStrategy)
|
||||
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.sharded_model.utils import col_model_deepcopy
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
@ -19,7 +19,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
|
|||
from common import CONFIG
|
||||
|
||||
|
||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy, ZeroTensorShardStrategy])
|
||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def run_zero_state_dict(shard_strategy_class):
|
||||
test_models = ['repeated_computed_layers', 'resnet18']
|
||||
shard_strategy = shard_strategy_class()
|
||||
|
|
Loading…
Reference in New Issue