[feature] fit non tensor broadcast (#6218)

feature/ray-rlhf
Hongxin Liu 2025-02-24 14:36:04 +08:00 committed by GitHub
parent de282dd694
commit 2bb71c6248
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 12 additions and 3 deletions

View File

@ -37,13 +37,22 @@ def ray_broadcast_tensor_dict(
rank = cc.get_rank(group_name)
if rank == src:
metadata = []
non_tensor_dict = {}
for k, v in tensor_dict.items():
metadata.append((k, v.shape, v.dtype))
if isinstance(v, torch.Tensor):
metadata.append((k, v.shape, v.dtype))
else:
non_tensor_dict[k] = v
else:
metadata = None
metadata = ray_broadcast_object(metadata, src, device, group_name)
non_tensor_dict = None
data_to_broadcast = (metadata, non_tensor_dict)
data_to_broadcast = ray_broadcast_object(data_to_broadcast, src, device, group_name)
metadata, non_tensor_dict = data_to_broadcast
if rank != src:
out_dict = {}
out_dict = non_tensor_dict
for k, shape, dtype in metadata:
if rank == src:
tensor = tensor_dict[k]