[lazyinit] fix clone and deepcopy (#3553)

pull/3579/head
Hongxin Liu 2 years ago committed by GitHub
parent 1c7734bc94
commit 4341f5e8e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,8 +14,8 @@ from colossalai.tensor.d_tensor.layout import Layout
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [
"arange",
"empty",
"full",
"empty",
"linspace",
"logspace",
"ones",
@ -324,7 +324,9 @@ class LazyTensor(torch.Tensor):
def clone(self) -> "LazyTensor":
def factory_fn():
return self.materialize().clone()
# if self is materialized, return self
new_tensor = self.materialize() if type(self) is LazyTensor else self
return new_tensor.clone()
target = LazyTensor(factory_fn, meta_data=self._meta_data)
@ -333,6 +335,26 @@ class LazyTensor(torch.Tensor):
def detach(self) -> Tensor:
return self
def __deepcopy__(self, memo):
if not self.is_leaf:
raise RuntimeError("Only Tensors created explicitly by the user "
"(graph leaves) support the deepcopy protocol at the moment")
if id(self) in memo:
return memo[id(self)]
def factory_fn():
# if self is materialized, return self
new_tensor = self.materialize() if type(self) is LazyTensor else self
copied = new_tensor.detach().clone()
if new_tensor.requires_grad:
copied.requires_grad_()
return copied
target = LazyTensor(factory_fn, meta_data=self._meta_data)
memo[id(self)] = target
return target
@property
def data(self):
return self

Loading…
Cancel
Save