free param.grad

pull/433/head
ver217 3 years ago
parent 9506a8beb2
commit ea6905a898

@@ -1,5 +1,5 @@
import functools
from ast import Try
from asyncio.log import logger
from collections import OrderedDict
from typing import Any, Optional
@@ -21,7 +21,7 @@ from colossalai.zero.sharded_param import ShardedParamV2
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
-from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
+from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                            get_gradient_predivide_factor)
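For reference, `free_storage` is a small helper in `_zero3_utils`. A minimal sketch of what such a helper typically does in ZeRO-style sharding code (an assumption, not copied from the repo): release a tensor's underlying storage while keeping its shape/dtype metadata, so the tensor costs essentially no memory until the storage is resized again.

import torch

def free_storage(data: torch.Tensor) -> None:
    # Sketch: drop the tensor's underlying storage; shape, dtype and device
    # metadata survive, but the memory footprint shrinks to ~0 bytes.
    if data.storage().size() > 0:
        # Only safe if this tensor is the sole occupant of its storage.
        assert data.storage_offset() == 0
        data.storage().resize_(0)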
@@ -218,6 +218,9 @@ class ShardedModelV2(nn.Module):
        else:
            self._reduce_scatter_callback(param, new_grad)
        orig_grad_data.record_stream(self.comm_stream)
+        empty_grad = torch.empty_like(grad)
+        free_storage(empty_grad)
+        return empty_grad

    def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
        if self.gradient_postdivide_factor > 1:
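Why the hook returns a storage-freed tensor: a gradient hook on a parameter may return a replacement gradient, and autograd then stores that replacement in `param.grad`. Returning an `empty_like` tensor whose storage has been freed means `param.grad` keeps only metadata instead of a full-size buffer once the real gradient has been handed off (e.g. to reduce-scatter). Below is a self-contained sketch of the pattern, with hypothetical names and a no-op stand-in for the reduce-scatter step; it illustrates the idea, not the exact ShardedModelV2 wiring.

import torch

def free_storage(data: torch.Tensor) -> None:
    # Same trick as sketched above: release the underlying storage, keep metadata.
    if data.storage().size() > 0:
        data.storage().resize_(0)

def make_grad_post_hook(handle_grad):
    # handle_grad stands in for the real work, e.g. reduce-scattering the
    # gradient into the shard owned by this rank.
    def hook(grad: torch.Tensor) -> torch.Tensor:
        handle_grad(grad)
        empty_grad = torch.empty_like(grad)   # same shape/dtype/device, fresh storage
        free_storage(empty_grad)              # free that storage right away
        return empty_grad                     # autograd stores this as param.grad
    return hook

p = torch.nn.Parameter(torch.randn(4, 4))
p.register_hook(make_grad_post_hook(lambda g: None))
p.sum().backward()
print(p.grad.shape, p.grad.storage().size())  # shape kept, storage size is 0

The benefit matches the commit title: after backward, the full-size gradient buffer is no longer pinned by `param.grad`; only whatever the reduce-scatter callback keeps (a shard) survives.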
