[hotfix] fix memory leak in backward of sharded model (#741)

pull/742/head
ver217 2022-04-13 09:59:05 +08:00 committed by GitHub
parent f4f42d4c3c
commit e6212f56cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 10 additions and 13 deletions

View File

@ -303,41 +303,38 @@ class ShardedModelV2(nn.Module):
assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
if not self._require_backward_grad_sync:
return
# used to cheat Pytorch, since we can't return None
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
# As torch didn't allow modifying grad in hook, we make a copy
grad = grad.clone()
if param.colo_attr.is_replicated:
self._reduce_scatter_handler(param, grad)
else:
self._save_grad(param, grad)
# used to cheat Pytorch, since we can't return None
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
return empty_grad
def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None:
self.comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.comm_stream):
new_grad = grad.clone()
if self.fp32_reduce_scatter:
new_grad.data = new_grad.data.to(param.dtype)
grad.data = grad.data.to(param.dtype)
if self.gradient_predivide_factor > 1.0:
# Average grad by world_size for consistency with PyTorch DDP.
new_grad.data.div_(self.gradient_predivide_factor)
orig_grad_data = new_grad.data
grad.data.div_(self.gradient_predivide_factor)
if self.world_size > 1:
grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
grad_chunks = chunk_and_pad(grad, self.reduce_scatter_process_group.size())
self.reducer.reduce_scatter_async(grad_chunks,
group=self.reduce_scatter_process_group,
callback_fn=functools.partial(self._reduce_scatter_callback, param))
else:
self._reduce_scatter_callback(param, new_grad)
orig_grad_data.record_stream(self.comm_stream)
self._reduce_scatter_callback(param, grad)
torch.cuda.current_stream().wait_stream(self.comm_stream)
def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
assert isinstance(reduced_grad,
torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}"
reduced_grad = reduced_grad.view(-1)
reduced_grad.data = reduced_grad.data.view(-1)
if self.gradient_postdivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
reduced_grad.data.div_(self.gradient_postdivide_factor)