mirror of https://github.com/hpcaitech/ColossalAI
[lazyinit] fix clone and deepcopy (#3553)
parent
1c7734bc94
commit
4341f5e8e6
|
@ -14,8 +14,8 @@ from colossalai.tensor.d_tensor.layout import Layout
|
||||||
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
|
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
|
||||||
_NORMAL_FACTORY = [
|
_NORMAL_FACTORY = [
|
||||||
"arange",
|
"arange",
|
||||||
"empty",
|
|
||||||
"full",
|
"full",
|
||||||
|
"empty",
|
||||||
"linspace",
|
"linspace",
|
||||||
"logspace",
|
"logspace",
|
||||||
"ones",
|
"ones",
|
||||||
|
@ -324,7 +324,9 @@ class LazyTensor(torch.Tensor):
|
||||||
def clone(self) -> "LazyTensor":
|
def clone(self) -> "LazyTensor":
|
||||||
|
|
||||||
def factory_fn():
|
def factory_fn():
|
||||||
return self.materialize().clone()
|
# if self is materialized, return self
|
||||||
|
new_tensor = self.materialize() if type(self) is LazyTensor else self
|
||||||
|
return new_tensor.clone()
|
||||||
|
|
||||||
target = LazyTensor(factory_fn, meta_data=self._meta_data)
|
target = LazyTensor(factory_fn, meta_data=self._meta_data)
|
||||||
|
|
||||||
|
@ -333,6 +335,26 @@ class LazyTensor(torch.Tensor):
|
||||||
def detach(self) -> Tensor:
|
def detach(self) -> Tensor:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def __deepcopy__(self, memo):
|
||||||
|
if not self.is_leaf:
|
||||||
|
raise RuntimeError("Only Tensors created explicitly by the user "
|
||||||
|
"(graph leaves) support the deepcopy protocol at the moment")
|
||||||
|
if id(self) in memo:
|
||||||
|
return memo[id(self)]
|
||||||
|
|
||||||
|
def factory_fn():
|
||||||
|
# if self is materialized, return self
|
||||||
|
new_tensor = self.materialize() if type(self) is LazyTensor else self
|
||||||
|
copied = new_tensor.detach().clone()
|
||||||
|
if new_tensor.requires_grad:
|
||||||
|
copied.requires_grad_()
|
||||||
|
return copied
|
||||||
|
|
||||||
|
target = LazyTensor(factory_fn, meta_data=self._meta_data)
|
||||||
|
|
||||||
|
memo[id(self)] = target
|
||||||
|
return target
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data(self):
|
def data(self):
|
||||||
return self
|
return self
|
||||||
|
|
Loading…
Reference in New Issue