@@ -303,41 +303,38 @@ class ShardedModelV2(nn.Module):
         assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
         if not self._require_backward_grad_sync:
             return
-        # used to cheat Pytorch, since we can't return None
-        empty_grad = torch.empty_like(grad)
-        free_storage(empty_grad)
         # As torch didn't allow modifying grad in hook, we make a copy
         grad = grad.clone()
         if param.colo_attr.is_replicated:
             self._reduce_scatter_handler(param, grad)
         else:
             self._save_grad(param, grad)
+        # used to cheat Pytorch, since we can't return None
+        empty_grad = torch.empty_like(grad)
+        free_storage(empty_grad)
         return empty_grad
 
     def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None:
         self.comm_stream.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self.comm_stream):
-            new_grad = grad.clone()
             if self.fp32_reduce_scatter:
-                new_grad.data = new_grad.data.to(param.dtype)
+                grad.data = grad.data.to(param.dtype)
             if self.gradient_predivide_factor > 1.0:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                new_grad.data.div_(self.gradient_predivide_factor)
-            orig_grad_data = new_grad.data
+                grad.data.div_(self.gradient_predivide_factor)
             if self.world_size > 1:
-                grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
+                grad_chunks = chunk_and_pad(grad, self.reduce_scatter_process_group.size())
                 self.reducer.reduce_scatter_async(grad_chunks,
                                                   group=self.reduce_scatter_process_group,
                                                   callback_fn=functools.partial(self._reduce_scatter_callback, param))
             else:
-                self._reduce_scatter_callback(param, new_grad)
-            orig_grad_data.record_stream(self.comm_stream)
+                self._reduce_scatter_callback(param, grad)
         torch.cuda.current_stream().wait_stream(self.comm_stream)
 
     def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
         assert isinstance(reduced_grad,
                           torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}"
-        reduced_grad = reduced_grad.view(-1)
+        reduced_grad.data = reduced_grad.data.view(-1)
         if self.gradient_postdivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
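
A note on the two `div_` calls in the hunk above: `gradient_predivide_factor` and `gradient_postdivide_factor` are typically chosen so that their product equals the data-parallel world size, so pre-dividing each rank's gradient before the reduce-scatter and post-dividing the reduced shard in the callback amounts to a plain average over ranks (the behaviour the "Average grad by world_size for consistency with PyTorch DDP" comments refer to), while keeping the summed values in a smaller range for fp16. A minimal sketch under that assumption; the factor values below are illustrative, not taken from any real config:

    import torch

    # Toy check of the averaging scheme above, assuming
    # gradient_predivide_factor * gradient_postdivide_factor == world_size.
    world_size = 8
    predivide_factor, postdivide_factor = 4.0, 2.0  # illustrative split, not real config values

    per_rank_grads = [torch.randn(16) for _ in range(world_size)]

    # Each rank pre-divides its gradient, the reduce-scatter sums them across ranks,
    # and the callback post-divides the reduced result.
    summed_then_scaled = sum(g / predivide_factor for g in per_rank_grads) / postdivide_factor

    # Equivalent (up to rounding) to the plain average PyTorch DDP would produce.
    plain_average = sum(per_rank_grads) / world_size
    assert torch.allclose(summed_then_scaled, plain_average)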
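
The `chunk_and_pad(grad, self.reduce_scatter_process_group.size())` call prepares the input list for the asynchronous reduce-scatter: the flattened gradient is split into one chunk per rank and padded so all chunks have equal length, and after the reduce-scatter each rank holds only the element-wise sum of its own chunk, which is then handed to `_reduce_scatter_callback`. A rough stand-in written from the call site rather than from the real helper, so treat the exact behaviour (zero padding, chunk sizing) as assumptions:

    import torch
    import torch.nn.functional as F

    def chunk_and_pad_sketch(tensor: torch.Tensor, num_chunks: int) -> list[torch.Tensor]:
        """Split a flattened tensor into num_chunks equally sized, zero-padded chunks."""
        flat = tensor.flatten()
        chunks = list(flat.chunk(num_chunks))
        full = chunks[0].numel()
        # Pad the trailing partial chunk so every chunk has the same length.
        if chunks[-1].numel() < full:
            chunks[-1] = F.pad(chunks[-1], (0, full - chunks[-1].numel()))
        # torch.chunk can return fewer than num_chunks pieces for very small tensors.
        while len(chunks) < num_chunks:
            chunks.append(torch.zeros(full, dtype=flat.dtype, device=flat.device))
        return chunks

    grad = torch.randn(10)
    parts = chunk_and_pad_sketch(grad, 4)  # sizes 3, 3, 3, 3 (last chunk padded from 1)
    assert all(p.numel() == parts[0].numel() for p in parts)
    # Rank i would then receive the element-wise sum of every rank's parts[i]
    # from the asynchronous reduce-scatter, and _reduce_scatter_callback runs on that shard.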