mirror of https://github.com/hpcaitech/ColossalAI
fix object_to_tensor usage when torch>=2.3.0 (#5820)
parent
2e28c793ce
commit
530283dba0
|
@ -91,7 +91,11 @@ def _broadcast_object_list(
|
|||
my_rank = dist.get_rank()
|
||||
# Serialize object_list elements to tensors on src rank.
|
||||
if my_rank == src:
|
||||
if Version(torch.__version__) >= Version("1.13.0"):
|
||||
if Version(torch.__version__) >= Version("2.3.0"):
|
||||
tensor_list, size_list = zip(
|
||||
*[c10d._object_to_tensor(obj, device=current_device, group=group) for obj in object_list]
|
||||
)
|
||||
elif Version(torch.__version__) >= Version("1.13.0"):
|
||||
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list])
|
||||
else:
|
||||
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
|
||||
|
@ -276,7 +280,11 @@ def _send_recv_serialization_object(
|
|||
send_object_tensor = None
|
||||
send_object_size_tensor = None
|
||||
if object is not None and send_dst is not None:
|
||||
if Version(torch.__version__) >= Version("1.13.0"):
|
||||
if Version(torch.__version__) >= Version("2.3.0"):
|
||||
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(
|
||||
object, device=current_device, group=send_group
|
||||
)
|
||||
elif Version(torch.__version__) >= Version("1.13.0"):
|
||||
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
|
||||
else:
|
||||
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object)
|
||||
|
|
Loading…
Reference in New Issue