Browse Source

[hotfix] meta tensor default device. (#2510)

pull/2527/head
Super Daniel 2 years ago committed by GitHub
parent
commit
c198c7c0b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      colossalai/fx/profiler/tensor.py

2
colossalai/fx/profiler/tensor.py

@ -43,7 +43,7 @@ class MetaTensor(torch.Tensor):
storage_offset=elem.storage_offset(), storage_offset=elem.storage_offset(),
dtype=elem.dtype, dtype=elem.dtype,
layout=elem.layout, layout=elem.layout,
device=fake_device if fake_device is not None else torch.device('cpu'), device=fake_device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
requires_grad=elem.requires_grad) # deceive the frontend for aten selections requires_grad=elem.requires_grad) # deceive the frontend for aten selections
r._tensor = elem r._tensor = elem
# ...the real tensor is held as an element on the tensor. # ...the real tensor is held as an element on the tensor.

Loading…
Cancel
Save