mirror of https://github.com/hpcaitech/ColossalAI
[feature] fit non tensor broadcast (#6218)
parent
de282dd694
commit
2bb71c6248
|
@ -37,13 +37,22 @@ def ray_broadcast_tensor_dict(
|
|||
rank = cc.get_rank(group_name)
|
||||
if rank == src:
|
||||
metadata = []
|
||||
non_tensor_dict = {}
|
||||
for k, v in tensor_dict.items():
|
||||
metadata.append((k, v.shape, v.dtype))
|
||||
if isinstance(v, torch.Tensor):
|
||||
metadata.append((k, v.shape, v.dtype))
|
||||
else:
|
||||
non_tensor_dict[k] = v
|
||||
else:
|
||||
metadata = None
|
||||
metadata = ray_broadcast_object(metadata, src, device, group_name)
|
||||
non_tensor_dict = None
|
||||
|
||||
data_to_broadcast = (metadata, non_tensor_dict)
|
||||
data_to_broadcast = ray_broadcast_object(data_to_broadcast, src, device, group_name)
|
||||
metadata, non_tensor_dict = data_to_broadcast
|
||||
|
||||
if rank != src:
|
||||
out_dict = {}
|
||||
out_dict = non_tensor_dict
|
||||
for k, shape, dtype in metadata:
|
||||
if rank == src:
|
||||
tensor = tensor_dict[k]
|
||||
|
|
Loading…
Reference in New Issue