From 556b9b7e1ad045c995014bb2c039b23d4c63f251 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Wed, 13 Jul 2022 00:18:56 +0800
Subject: [PATCH] [hotfix] Dist Mgr gather torch version (#1284)

* make it faster

* [hotfix] torchvison fx tests

* [hotfix] rename duplicated named test_gpt.py

* [hotfix] dist mgr torch version
---
 colossalai/tensor/dist_spec_mgr.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/tensor/dist_spec_mgr.py
index 51b7bfb91..f1dc241a8 100644
--- a/colossalai/tensor/dist_spec_mgr.py
+++ b/colossalai/tensor/dist_spec_mgr.py
@@ -88,11 +88,13 @@ class DistSpecManager:
             torch.Tensor: a replicated tensor.
         """
         assert old_dist_spec.placement.value == 's', f"The old_dist_spec of DistSpecManager._gather must be SHARD!"
-        if version.parse(torch.__version__) < version.parse("1.11.0"):
+        is_cpu_tensor = False
+        if tensor.device.type == 'cpu':
             # pytorch lower than 1.11 dose not support gather a cpu tensor.
             # Therefore, we transfer tensor to GPU before gather.
             saved_dev = tensor.device
             tensor.data = tensor.data.cuda()
+            is_cpu_tensor = True
 
         buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())]
         assert tensor.device.type == 'cuda'
@@ -106,7 +108,7 @@ class DistSpecManager:
             buffer = new_buffer
         assert len(buffer) == 1
 
-        if version.parse(torch.__version__) < version.parse("1.11.0"):
+        if is_cpu_tensor:
             buffer[0].data = buffer[0].data.to(saved_dev)
         return buffer[0]
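
The pattern this patch adopts: instead of branching on the installed torch
version, the gather path records whether the shard arrived on the CPU, stages
it on the GPU for the collective, and moves the gathered result back to the
original device afterwards. Below is a minimal sketch of that pattern; the
helper name demo_gather, the world_size argument, and the single torch.cat
along dim 0 are illustrative assumptions, not the repository's exact _gather
implementation.

    import torch
    import torch.distributed as dist

    def demo_gather(tensor: torch.Tensor, world_size: int) -> torch.Tensor:
        is_cpu_tensor = False
        if tensor.device.type == 'cpu':
            # GPU-backed (e.g. NCCL) collectives need CUDA tensors, so stage
            # the shard on the GPU and remember where it came from.
            saved_dev = tensor.device
            tensor = tensor.cuda()
            is_cpu_tensor = True

        # Gather one shard per rank, then concatenate them along dim 0
        # (the real _gather concatenates along each sharded dim).
        buffer = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(buffer, tensor)
        gathered = torch.cat(buffer, dim=0)

        if is_cpu_tensor:
            # Hand the result back on the caller's original device.
            gathered = gathered.to(saved_dev)
        return gathered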