diff --git a/colossalai/utils/activation_checkpoint.py b/colossalai/utils/activation_checkpoint.py index 2bc10efd8..88cc7e202 100644 --- a/colossalai/utils/activation_checkpoint.py +++ b/colossalai/utils/activation_checkpoint.py @@ -12,7 +12,7 @@ def copy_to_device(obj, device): if torch.is_tensor(obj): # Notice: # When in no_grad context, requires_gard is False after movement - ret = obj.to(device) + ret = obj.to(device).detach() ret.requires_grad = obj.requires_grad return ret elif isinstance(obj, list):