polish code

pull/403/head
ver217 2022-03-14 15:48:55 +08:00
parent 54fd37f0e0
commit 63469c0f91
1 changed files with 3 additions and 0 deletions

View File

@ -23,6 +23,9 @@ class BucketTensorShardStrategy(TensorShardStrategy):
for i in range(self.world_size): for i in range(self.world_size):
if i == self.local_rank: if i == self.local_rank:
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device())) buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
# Release payload here, to decrease peak memory usage
for t in tensor_list:
t.reset_payload(None)
else: else:
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device())) buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group) dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group)