mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] Dist Mgr gather torch version (#1284)
* make it faster * [hotfix] torchvison fx tests * [hotfix] rename duplicated named test_gpt.py * [hotfix] dist mgr torch versionpull/1248/merge
parent
7e8114a8dd
commit
556b9b7e1a
|
@ -88,11 +88,13 @@ class DistSpecManager:
|
||||||
torch.Tensor: a replicated tensor.
|
torch.Tensor: a replicated tensor.
|
||||||
"""
|
"""
|
||||||
assert old_dist_spec.placement.value == 's', f"The old_dist_spec of DistSpecManager._gather must be SHARD!"
|
assert old_dist_spec.placement.value == 's', f"The old_dist_spec of DistSpecManager._gather must be SHARD!"
|
||||||
if version.parse(torch.__version__) < version.parse("1.11.0"):
|
is_cpu_tensor = False
|
||||||
|
if tensor.device.type == 'cpu':
|
||||||
# pytorch lower than 1.11 dose not support gather a cpu tensor.
|
# pytorch lower than 1.11 dose not support gather a cpu tensor.
|
||||||
# Therefore, we transfer tensor to GPU before gather.
|
# Therefore, we transfer tensor to GPU before gather.
|
||||||
saved_dev = tensor.device
|
saved_dev = tensor.device
|
||||||
tensor.data = tensor.data.cuda()
|
tensor.data = tensor.data.cuda()
|
||||||
|
is_cpu_tensor = True
|
||||||
|
|
||||||
buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())]
|
buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())]
|
||||||
assert tensor.device.type == 'cuda'
|
assert tensor.device.type == 'cuda'
|
||||||
|
@ -106,7 +108,7 @@ class DistSpecManager:
|
||||||
buffer = new_buffer
|
buffer = new_buffer
|
||||||
assert len(buffer) == 1
|
assert len(buffer) == 1
|
||||||
|
|
||||||
if version.parse(torch.__version__) < version.parse("1.11.0"):
|
if is_cpu_tensor:
|
||||||
buffer[0].data = buffer[0].data.to(saved_dev)
|
buffer[0].data = buffer[0].data.to(saved_dev)
|
||||||
return buffer[0]
|
return buffer[0]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue