[hotfix] Dist Mgr gather torch version (#1284)

* make it faster

* [hotfix] torchvison fx tests

* [hotfix] rename duplicated named test_gpt.py

* [hotfix] dist mgr torch version
pull/1248/merge
Jiarui Fang 2 years ago committed by GitHub
parent 7e8114a8dd
commit 556b9b7e1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -88,11 +88,13 @@ class DistSpecManager:
torch.Tensor: a replicated tensor.
"""
assert old_dist_spec.placement.value == 's', f"The old_dist_spec of DistSpecManager._gather must be SHARD!"
if version.parse(torch.__version__) < version.parse("1.11.0"):
is_cpu_tensor = False
if tensor.device.type == 'cpu':
# pytorch lower than 1.11 dose not support gather a cpu tensor.
# Therefore, we transfer tensor to GPU before gather.
saved_dev = tensor.device
tensor.data = tensor.data.cuda()
is_cpu_tensor = True
buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())]
assert tensor.device.type == 'cuda'
@ -106,7 +108,7 @@ class DistSpecManager:
buffer = new_buffer
assert len(buffer) == 1
if version.parse(torch.__version__) < version.parse("1.11.0"):
if is_cpu_tensor:
buffer[0].data = buffer[0].data.to(saved_dev)
return buffer[0]

Loading…
Cancel
Save