mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix memory leak in backward of sharded model (#741)
parent f4f42d4c3c
commit e6212f56cd
@@ -303,41 +303,38 @@ class ShardedModelV2(nn.Module):
         assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
         if not self._require_backward_grad_sync:
             return
-        # used to cheat Pytorch, since we can't return None
-        empty_grad = torch.empty_like(grad)
-        free_storage(empty_grad)
+        # As torch didn't allow modifying grad in hook, we make a copy
+        grad = grad.clone()
         if param.colo_attr.is_replicated:
             self._reduce_scatter_handler(param, grad)
         else:
             self._save_grad(param, grad)
+        # used to cheat Pytorch, since we can't return None
+        empty_grad = torch.empty_like(grad)
+        free_storage(empty_grad)
         return empty_grad
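Before the patch, the hook allocated the placeholder first and handed the original gradient straight to the reduce-scatter path, which then made its own copy on the communication stream; after the patch the hook clones once, mutates only that copy, and builds the storage-freed placeholder just before returning. Below is a minimal sketch of the two pieces at play, not taken from this commit: it uses the public `Tensor.register_hook` API, and `free_storage` here is a stand-in for ColossalAI's helper of the same name.

    import torch

    def free_storage(t: torch.Tensor) -> None:
        # Stand-in for the free_storage() in the diff: drop the payload bytes
        # but keep the tensor object (shape/dtype metadata) so it can still be returned.
        if t.storage().size() > 0:
            t.storage().resize_(0)

    def post_backward_hook(grad: torch.Tensor) -> torch.Tensor:
        # A hook must not modify its argument in place, so clone once and
        # mutate only the copy (stands in for dtype cast / pre-divide).
        grad = grad.clone()
        grad.div_(2)
        return grad                       # autograd continues with the copy

    p = torch.randn(4, requires_grad=True)
    p.register_hook(post_backward_hook)
    p.sum().backward()
    print(p.grad)                         # 0.5s instead of 1.0s

    # The "cheat PyTorch" placeholder from the diff: same shape, ~0 bytes of data.
    placeholder = torch.empty_like(p)
    free_storage(placeholder)
    print(placeholder.shape, placeholder.storage().size())   # torch.Size([4]) 0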

     def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None:
         self.comm_stream.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self.comm_stream):
-            new_grad = grad.clone()
             if self.fp32_reduce_scatter:
-                new_grad.data = new_grad.data.to(param.dtype)
+                grad.data = grad.data.to(param.dtype)
             if self.gradient_predivide_factor > 1.0:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                new_grad.data.div_(self.gradient_predivide_factor)
-            orig_grad_data = new_grad.data
+                grad.data.div_(self.gradient_predivide_factor)
             if self.world_size > 1:
-                grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
+                grad_chunks = chunk_and_pad(grad, self.reduce_scatter_process_group.size())
                 self.reducer.reduce_scatter_async(grad_chunks,
                                                   group=self.reduce_scatter_process_group,
                                                   callback_fn=functools.partial(self._reduce_scatter_callback, param))
             else:
-                self._reduce_scatter_callback(param, new_grad)
-            orig_grad_data.record_stream(self.comm_stream)
+                self._reduce_scatter_callback(param, grad)
         torch.cuda.current_stream().wait_stream(self.comm_stream)
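The handler keeps the usual two-sided stream handshake: the communication stream waits on the compute stream that produced the gradient, the scaling and reduce-scatter are queued on the communication stream, and the compute stream waits again before anyone reads the result. A generic sketch of that handshake follows (CUDA only, not the library code); `record_stream` appears for the same reason the removed `orig_grad_data.record_stream(...)` line existed, to keep the caching allocator from recycling memory that a side stream is still using.

    import torch

    def scale_on_side_stream(grad: torch.Tensor, comm_stream: torch.cuda.Stream) -> None:
        # 1) Side stream waits for everything already queued on the current
        #    stream (e.g. the backward kernels that produced `grad`).
        comm_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(comm_stream):
            # 2) `grad` was allocated on the current stream but is consumed here.
            #    record_stream tells the caching allocator not to reuse its memory
            #    if the last Python reference is dropped before comm_stream is done.
            grad.record_stream(comm_stream)
            grad.div_(2)   # stand-in for dtype cast / pre-divide / reduce-scatter
        # 3) The current stream must not read the result until comm_stream is done.
        torch.cuda.current_stream().wait_stream(comm_stream)

    if torch.cuda.is_available():
        g = torch.ones(8, device='cuda')
        scale_on_side_stream(g, torch.cuda.Stream())
        print(g)   # tensor of 0.5s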

     def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
         assert isinstance(reduced_grad,
                           torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}"
-        reduced_grad = reduced_grad.view(-1)
+        reduced_grad.data = reduced_grad.data.view(-1)
         if self.gradient_postdivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
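The last hunk replaces `reduced_grad = reduced_grad.view(-1)` with an assignment through `.data`: rebinding the local name flattens nothing outside the callback, while assigning `.data` flattens the tensor object that every other holder already references. A small sketch of that difference, not tied to this codebase:

    import torch

    def flatten_rebind(t: torch.Tensor) -> None:
        t = t.view(-1)            # rebinds the local name; the caller's tensor is unchanged

    def flatten_through_data(t: torch.Tensor) -> None:
        t.data = t.data.view(-1)  # swaps payload/metadata on the shared tensor object

    a = torch.zeros(2, 3)
    flatten_rebind(a)
    print(a.shape)                # torch.Size([2, 3])

    b = torch.zeros(2, 3)
    flatten_through_data(b)
    print(b.shape)                # torch.Size([6])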