diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp
index 166c698f6..d08f3dbc7 100644
--- a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp
+++ b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp
@@ -4,7 +4,8 @@
 #include
 #include

-#if TORCH_VERSION_MINOR >= 13
+#if TORCH_VERSION_MAJOR > 1 || \
+    (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
 #include
 #else
 #include
diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h
index db50071b6..6505eb31f 100644
--- a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h
+++ b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h
@@ -6,7 +6,8 @@
 #include
 #include

-#if TORCH_VERSION_MINOR >= 13
+#if TORCH_VERSION_MAJOR > 1 || \
+    (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
 #include
 #else
 #include
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 670c210e3..93ab982cc 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -69,6 +69,7 @@ class ColoTensor(torch.Tensor):
         data (torch.Tensor): a torch tensor used as the payload the colotensor.
         spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
     """
+    torch_major = int(torch.__version__.split('.')[0])
     torch_minor = int(torch.__version__.split('.')[1])

     def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
@@ -168,7 +169,7 @@ class ColoTensor(torch.Tensor):
         if func in _COLOSSAL_OPS:
             func = _COLOSSAL_OPS[func]

-        if cls.torch_minor >= 12:
+        if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12):
             # in order to trigger pre-op hook in the forward of checkpoint module
             # we have to capture the `backward` function
             # and make sure that it does not in `torch._C.DisableTorchFunction()` context
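
Note: the same fix recurs in both the C++ macros and the Python class attribute. A check that compares only the minor version (`TORCH_VERSION_MINOR >= 13`, `cls.torch_minor >= 12`) silently turns false on torch 2.x, because the minor version resets to 0 even though every 2.x release is newer than any 1.x release. A minimal standalone sketch of the corrected comparison follows; the helper name `version_at_least` is illustrative and not part of the patch:

    def version_at_least(version: str, major: int, minor: int) -> bool:
        """Return True if `version` (e.g. "1.13.1" or "2.0.0+cu118") is >= major.minor."""
        parts = version.split('.')
        v_major, v_minor = int(parts[0]), int(parts[1])
        # Compare the major version first; only fall back to the minor version
        # when the majors are equal, mirroring the patched checks above.
        return v_major > major or (v_major == major and v_minor >= minor)

    assert version_at_least("1.13.1", 1, 13)      # takes the new include path
    assert not version_at_least("1.12.1", 1, 13)  # keeps the old include path
    assert version_at_least("2.0.0", 1, 12)       # the old minor-only check was False here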