diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index 028d0854c..5c5c2c421 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -303,41 +303,38 @@ class ShardedModelV2(nn.Module): assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients' if not self._require_backward_grad_sync: return - + # used to cheat Pytorch, since we can't return None + empty_grad = torch.empty_like(grad) + free_storage(empty_grad) + # As torch didn't allow modifying grad in hook, we make a copy + grad = grad.clone() if param.colo_attr.is_replicated: self._reduce_scatter_handler(param, grad) else: self._save_grad(param, grad) - - # used to cheat Pytorch, since we can't return None - empty_grad = torch.empty_like(grad) - free_storage(empty_grad) return empty_grad def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None: self.comm_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self.comm_stream): - new_grad = grad.clone() if self.fp32_reduce_scatter: - new_grad.data = new_grad.data.to(param.dtype) + grad.data = grad.data.to(param.dtype) if self.gradient_predivide_factor > 1.0: # Average grad by world_size for consistency with PyTorch DDP. - new_grad.data.div_(self.gradient_predivide_factor) - orig_grad_data = new_grad.data + grad.data.div_(self.gradient_predivide_factor) if self.world_size > 1: - grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size()) + grad_chunks = chunk_and_pad(grad, self.reduce_scatter_process_group.size()) self.reducer.reduce_scatter_async(grad_chunks, group=self.reduce_scatter_process_group, callback_fn=functools.partial(self._reduce_scatter_callback, param)) else: - self._reduce_scatter_callback(param, new_grad) - orig_grad_data.record_stream(self.comm_stream) + self._reduce_scatter_callback(param, grad) torch.cuda.current_stream().wait_stream(self.comm_stream) def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None: assert isinstance(reduced_grad, torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}" - reduced_grad = reduced_grad.view(-1) + reduced_grad.data = reduced_grad.data.view(-1) if self.gradient_postdivide_factor > 1: # Average grad by world_size for consistency with PyTorch DDP. reduced_grad.data.div_(self.gradient_postdivide_factor)