From 530283dba034b20c8f3562a661995e38926f3e80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=82=A2=E3=83=9E=E3=83=87=E3=82=A6=E3=82=B9?= Date: Thu, 4 Jul 2024 10:53:58 +0800 Subject: [PATCH] fix object_to_tensor usage when torch>=2.3.0 (#5820) --- colossalai/pipeline/p2p.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index ed190eb08..b7b284213 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -91,7 +91,11 @@ def _broadcast_object_list( my_rank = dist.get_rank() # Serialize object_list elements to tensors on src rank. if my_rank == src: - if Version(torch.__version__) >= Version("1.13.0"): + if Version(torch.__version__) >= Version("2.3.0"): + tensor_list, size_list = zip( + *[c10d._object_to_tensor(obj, device=current_device, group=group) for obj in object_list] + ) + elif Version(torch.__version__) >= Version("1.13.0"): tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list]) else: tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list]) @@ -276,7 +280,11 @@ def _send_recv_serialization_object( send_object_tensor = None send_object_size_tensor = None if object is not None and send_dst is not None: - if Version(torch.__version__) >= Version("1.13.0"): + if Version(torch.__version__) >= Version("2.3.0"): + send_object_tensor, send_object_size_tensor = c10d._object_to_tensor( + object, device=current_device, group=send_group + ) + elif Version(torch.__version__) >= Version("1.13.0"): send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device) else: send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object)