[NFC] polish colossalai/zero/sharded_model/reduce_scatter.py code style (#1554)

pull/1550/head
Fazzie-Maqianli 2022-09-08 16:56:20 +08:00 committed by Frank Lee
parent 2ac46f7be4
commit 06dccdde44
1 changed files with 13 additions and 13 deletions

View File

@ -20,6 +20,7 @@ else:
class Bucket:
def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device)
self.group = group
@ -34,13 +35,13 @@ class Bucket:
return
# reduce-scatter bucket
if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
dist._reduce_scatter_base(
self.output_shard[: self.offset], self.buffer[:, : self.offset].contiguous(), group=self.group
)
dist._reduce_scatter_base(self.output_shard[:self.offset],
self.buffer[:, :self.offset].contiguous(),
group=self.group)
else:
dist.reduce_scatter(
self.output_shard[: self.offset], list(self.buffer[:, : self.offset].unbind(0)), group=self.group
)
dist.reduce_scatter(self.output_shard[:self.offset],
list(self.buffer[:, :self.offset].unbind(0)),
group=self.group)
# execute post-reduction callbacks
for callback_fn in self.callbacks:
callback_fn()
@ -141,8 +142,7 @@ class ReduceScatterBucketer:
"""
world_size = group.size()
assert (
len(input_list) == world_size
assert (len(input_list) == world_size
), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"
first_input = input_list[0]