mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] Dist Mgr gather torch version (#1284)
* make it faster * [hotfix] torchvison fx tests * [hotfix] rename duplicated named test_gpt.py * [hotfix] dist mgr torch versionpull/1248/merge
parent
7e8114a8dd
commit
556b9b7e1a
|
@ -88,11 +88,13 @@ class DistSpecManager:
|
|||
torch.Tensor: a replicated tensor.
|
||||
"""
|
||||
assert old_dist_spec.placement.value == 's', f"The old_dist_spec of DistSpecManager._gather must be SHARD!"
|
||||
if version.parse(torch.__version__) < version.parse("1.11.0"):
|
||||
is_cpu_tensor = False
|
||||
if tensor.device.type == 'cpu':
|
||||
# pytorch lower than 1.11 dose not support gather a cpu tensor.
|
||||
# Therefore, we transfer tensor to GPU before gather.
|
||||
saved_dev = tensor.device
|
||||
tensor.data = tensor.data.cuda()
|
||||
is_cpu_tensor = True
|
||||
|
||||
buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())]
|
||||
assert tensor.device.type == 'cuda'
|
||||
|
@ -106,7 +108,7 @@ class DistSpecManager:
|
|||
buffer = new_buffer
|
||||
assert len(buffer) == 1
|
||||
|
||||
if version.parse(torch.__version__) < version.parse("1.11.0"):
|
||||
if is_cpu_tensor:
|
||||
buffer[0].data = buffer[0].data.to(saved_dev)
|
||||
return buffer[0]
|
||||
|
||||
|
|
Loading…
Reference in New Issue