diff --git a/colossalai/tensor/chunk.py b/colossalai/tensor/chunk.py index b3cb07328..4cacb0c7b 100644 --- a/colossalai/tensor/chunk.py +++ b/colossalai/tensor/chunk.py @@ -455,7 +455,7 @@ class ChunkManager: chunk (Chunk): the chunk to move to target device device (torch.device): target device """ - if chunk.data.device == device: + if chunk.device_type == device.type: return if chunk.can_move_device and not chunk.is_empty: self.total_mem[chunk.device_type] -= chunk.mem