mirror of https://github.com/hpcaitech/ColossalAI
free param.grad
parent
9506a8beb2
commit
ea6905a898
|
@ -1,5 +1,5 @@
|
|||
import functools
|
||||
from ast import Try
|
||||
from asyncio.log import logger
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Optional
|
||||
|
||||
|
@ -21,7 +21,7 @@ from colossalai.zero.sharded_param import ShardedParamV2
|
|||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
|
||||
from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
|
||||
get_gradient_predivide_factor)
|
||||
|
||||
|
||||
|
@ -218,6 +218,9 @@ class ShardedModelV2(nn.Module):
|
|||
else:
|
||||
self._reduce_scatter_callback(param, new_grad)
|
||||
orig_grad_data.record_stream(self.comm_stream)
|
||||
empty_grad = torch.empty_like(grad)
|
||||
free_storage(empty_grad)
|
||||
return empty_grad
|
||||
|
||||
def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
|
||||
if self.gradient_postdivide_factor > 1:
|
||||
|
|
Loading…
Reference in New Issue