diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 2de39de4a..594a14ec7 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -1,5 +1,5 @@
 import functools
-from ast import Try
+from asyncio.log import logger
 from collections import OrderedDict
 from typing import Any, Optional
@@ -21,7 +21,7 @@ from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 
-from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
+from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                            get_gradient_predivide_factor)
@@ -218,6 +218,9 @@ class ShardedModelV2(nn.Module):
         else:
             self._reduce_scatter_callback(param, new_grad)
         orig_grad_data.record_stream(self.comm_stream)
+        empty_grad = torch.empty_like(grad)
+        free_storage(empty_grad)
+        return empty_grad
 
     def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
         if self.gradient_postdivide_factor > 1:
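
Note on the change: PyTorch gradient hooks may return a tensor that replaces the incoming gradient. By returning `empty_grad`, whose storage has been released via the newly imported `free_storage` helper, the hook lets the original gradient buffer be reclaimed as soon as it has been reduce-scattered, rather than keeping a full-size copy alive per parameter. Below is a minimal sketch of `free_storage`, assuming it mirrors FairScale's helper of the same name; the actual body in `_zero3_utils` is not shown in this diff.

    import torch

    def free_storage(data: torch.Tensor) -> None:
        # Sketch (assumed implementation): release the tensor's underlying
        # storage while keeping the tensor object itself alive, so its
        # shape/dtype metadata can still serve as a placeholder gradient.
        if data.storage().size() > 0:
            # Resizing the storage would corrupt any other view sharing it,
            # so the tensor must own the storage from offset 0.
            assert data.storage_offset() == 0
            data.storage().resize_(0)

`torch.empty_like(grad)` allocates fresh storage that no other tensor shares, which is what makes the immediate `free_storage(empty_grad)` call safe.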