mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix error for torch 2.0 (#2243)
parent
b7d0990c61
commit
85178a397a
|
@ -4,7 +4,8 @@
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <torch/torch.h>
|
#include <torch/torch.h>
|
||||||
|
|
||||||
#if TORCH_VERSION_MINOR >= 13
|
#if TORCH_VERSION_MAJOR > 1 || \
|
||||||
|
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
|
||||||
#include <torch/csrc/distributed/c10d/Types.hpp>
|
#include <torch/csrc/distributed/c10d/Types.hpp>
|
||||||
#else
|
#else
|
||||||
#include <c10d/Types.hpp>
|
#include <c10d/Types.hpp>
|
||||||
|
|
|
@ -6,7 +6,8 @@
|
||||||
#include <cuda_runtime_api.h>
|
#include <cuda_runtime_api.h>
|
||||||
#include <torch/torch.h>
|
#include <torch/torch.h>
|
||||||
|
|
||||||
#if TORCH_VERSION_MINOR >= 13
|
#if TORCH_VERSION_MAJOR > 1 || \
|
||||||
|
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
|
||||||
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
||||||
#else
|
#else
|
||||||
#include <c10d/ProcessGroup.hpp>
|
#include <c10d/ProcessGroup.hpp>
|
||||||
|
|
|
@ -69,6 +69,7 @@ class ColoTensor(torch.Tensor):
|
||||||
data (torch.Tensor): a torch tensor used as the payload the colotensor.
|
data (torch.Tensor): a torch tensor used as the payload the colotensor.
|
||||||
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
|
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
|
||||||
"""
|
"""
|
||||||
|
torch_major = int(torch.__version__.split('.')[0])
|
||||||
torch_minor = int(torch.__version__.split('.')[1])
|
torch_minor = int(torch.__version__.split('.')[1])
|
||||||
|
|
||||||
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
|
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
|
||||||
|
@ -168,7 +169,7 @@ class ColoTensor(torch.Tensor):
|
||||||
if func in _COLOSSAL_OPS:
|
if func in _COLOSSAL_OPS:
|
||||||
func = _COLOSSAL_OPS[func]
|
func = _COLOSSAL_OPS[func]
|
||||||
|
|
||||||
if cls.torch_minor >= 12:
|
if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12):
|
||||||
# in order to trigger pre-op hook in the forward of checkpoint module
|
# in order to trigger pre-op hook in the forward of checkpoint module
|
||||||
# we have to capture the `backward` function
|
# we have to capture the `backward` function
|
||||||
# and make sure that it does not in `torch._C.DisableTorchFunction()` context
|
# and make sure that it does not in `torch._C.DisableTorchFunction()` context
|
||||||
|
|
Loading…
Reference in New Issue