diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py index 7606f17cf..2ee5e5c47 100644 --- a/colossalai/fx/profiler/tensor.py +++ b/colossalai/fx/profiler/tensor.py @@ -43,7 +43,7 @@ class MetaTensor(torch.Tensor): storage_offset=elem.storage_offset(), dtype=elem.dtype, layout=elem.layout, - device=fake_device if fake_device is not None else torch.device('cpu'), + device=fake_device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')), requires_grad=elem.requires_grad) # deceive the frontend for aten selections r._tensor = elem # ...the real tensor is held as an element on the tensor.