fix object_to_tensor usage when torch>=2.3.0 (#5820)

pull/5912/head
アマデウス 2024-07-04 10:53:58 +08:00 committed by Hongxin Liu
parent 2e28c793ce
commit 530283dba0
1 changed files with 10 additions and 2 deletions

View File

@ -91,7 +91,11 @@ def _broadcast_object_list(
my_rank = dist.get_rank()
# Serialize object_list elements to tensors on src rank.
if my_rank == src:
if Version(torch.__version__) >= Version("1.13.0"):
if Version(torch.__version__) >= Version("2.3.0"):
tensor_list, size_list = zip(
*[c10d._object_to_tensor(obj, device=current_device, group=group) for obj in object_list]
)
elif Version(torch.__version__) >= Version("1.13.0"):
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list])
else:
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
@ -276,7 +280,11 @@ def _send_recv_serialization_object(
send_object_tensor = None
send_object_size_tensor = None
if object is not None and send_dst is not None:
if Version(torch.__version__) >= Version("1.13.0"):
if Version(torch.__version__) >= Version("2.3.0"):
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(
object, device=current_device, group=send_group
)
elif Version(torch.__version__) >= Version("1.13.0"):
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
else:
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object)