```diff
@@ -69,6 +69,7 @@ class ColoTensor(torch.Tensor):
         data (torch.Tensor): a torch tensor used as the payload the colotensor.
         spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
     """
+    torch_major = int(torch.__version__.split('.')[0])
     torch_minor = int(torch.__version__.split('.')[1])
 
     def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
```
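The added `torch_major` attribute mirrors the existing `torch_minor` parsing of `torch.__version__`. A minimal standalone sketch of what the two expressions evaluate to (the sample version strings are illustrative, not taken from the PR):

```python
# Standalone illustration of the version parsing used by ColoTensor.
# The sample version strings below are assumptions for demonstration only.
for version in ("1.11.0", "1.13.1", "2.0.0"):
    major = int(version.split('.')[0])  # what torch_major evaluates to
    minor = int(version.split('.')[1])  # what torch_minor evaluates to
    print(f"{version}: major={major}, minor={minor}")
# 1.11.0: major=1, minor=11
# 1.13.1: major=1, minor=13
# 2.0.0:  major=2, minor=0
```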
```diff
@@ -168,7 +169,7 @@ class ColoTensor(torch.Tensor):
         if func in _COLOSSAL_OPS:
             func = _COLOSSAL_OPS[func]
 
-        if cls.torch_minor >= 12:
+        if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12):
            # in order to trigger pre-op hook in the forward of checkpoint module
            # we have to capture the `backward` function
            # and make sure that it does not in `torch._C.DisableTorchFunction()` context
```
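The guard change matters because under torch 2.x the minor version resets to 0, so the old minor-only check goes false even though the >=1.12 code path is still needed. A hedged sketch contrasting the two predicates (the helper names here are hypothetical, for illustration only):

```python
# Hypothetical helpers contrasting the old and new version guards.
def old_guard(major: int, minor: int) -> bool:
    return minor >= 12  # ignores the major version; wrongly False for 2.0

def new_guard(major: int, minor: int) -> bool:
    return major > 1 or (major == 1 and minor >= 12)

for major, minor in [(1, 11), (1, 12), (1, 13), (2, 0)]:
    print((major, minor), old_guard(major, minor), new_guard(major, minor))
# (1, 11) False False
# (1, 12) True  True
# (1, 13) True  True
# (2, 0)  False True   <- the case the patch fixes
```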