[hotfix] fix error for torch 2.0 (#2243)

Author: xcnick
Date:   2022-12-30 23:11:55 +08:00 (committed by GitHub)
commit 85178a397a (parent b7d0990c61)
3 changed files with 6 additions and 3 deletions


@@ -4,7 +4,8 @@
 #include <torch/extension.h>
 #include <torch/torch.h>
-#if TORCH_VERSION_MINOR >= 13
+#if TORCH_VERSION_MAJOR > 1 || \
+    (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
 #include <torch/csrc/distributed/c10d/Types.hpp>
 #else
 #include <c10d/Types.hpp>
 #endif


@@ -6,7 +6,8 @@
 #include <cuda_runtime_api.h>
 #include <torch/torch.h>
-#if TORCH_VERSION_MINOR >= 13
+#if TORCH_VERSION_MAJOR > 1 || \
+    (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
 #else
 #include <c10d/ProcessGroup.hpp>
 #endif
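The two hunks above fix the same bug in both headers: under torch 2.0, TORCH_VERSION_MINOR is 0, so the old `TORCH_VERSION_MINOR >= 13` guard falls through to the legacy `<c10d/...>` include path, which no longer exists, and the extension fails to compile. A minimal Python sketch of the corrected comparison (illustrative only, not part of the patch):

    # Old vs. new version guard, written out for a few torch releases.
    for major, minor in [(1, 12), (1, 13), (2, 0)]:
        old_check = minor >= 13                                # wrongly False for 2.0
        new_check = major > 1 or (major == 1 and minor >= 13)  # True for 1.13+ and all 2.x
        print(f"torch {major}.{minor}: old={old_check}, new={new_check}")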


@@ -69,6 +69,7 @@ class ColoTensor(torch.Tensor):
         data (torch.Tensor): a torch tensor used as the payload of the colotensor.
         spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
     """
+    torch_major = int(torch.__version__.split('.')[0])
     torch_minor = int(torch.__version__.split('.')[1])
 
     def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
@@ -168,7 +169,7 @@ class ColoTensor(torch.Tensor):
         if func in _COLOSSAL_OPS:
             func = _COLOSSAL_OPS[func]
 
-        if cls.torch_minor >= 12:
+        if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12):
             # in order to trigger pre-op hook in the forward of checkpoint module
             # we have to capture the `backward` function
             # and make sure that it is not in `torch._C.DisableTorchFunction()` context
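
On the Python side the patch applies the same two-field comparison at runtime: ColoTensor now parses the major version as well, so the `>= 12` gate keeps firing on torch 2.x. A minimal, self-contained sketch of that gate (the print branches are illustrative, not the PR's code):

    import torch

    # Parse "major.minor" out of torch.__version__, as the patch does. Pre-release
    # tags such as "2.1.0a0+git..." land in the third field, so int() on the
    # first two fields is safe for release and nightly builds alike.
    torch_major = int(torch.__version__.split('.')[0])
    torch_minor = int(torch.__version__.split('.')[1])

    # torch >= 1.12, including every 2.x release: capture `backward` so that
    # pre-op hooks of checkpointed modules still fire under __torch_function__.
    if torch_major > 1 or (torch_major == 1 and torch_minor >= 12):
        print("use the >= 1.12 code path")
    else:
        print("use the legacy code path")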