@@ -1,5 +1,5 @@
import functools
from ast import Try
from asyncio.log import logger
from collections import OrderedDict
from typing import Any, Optional
@@ -21,7 +21,7 @@ from colossalai.zero.sharded_param import ShardedParamV2
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter

-from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
+from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                           get_gradient_predivide_factor)
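The newly imported `free_storage` supports the pattern added at the end of the backward hook in the next hunk: the hook returns a gradient tensor whose metadata is intact but whose storage has been released, so autograd gets a valid placeholder without the memory cost. A minimal sketch of what such a helper typically does (an assumption based on the fairscale-style idiom; the actual body lives in `_zero3_utils`):

```python
import torch


def free_storage(data: torch.Tensor) -> None:
    # Sketch only; the real helper is defined in colossalai's _zero3_utils.
    # Release the underlying storage in place. The tensor keeps its shape,
    # dtype and device, so it can still be returned as a placeholder
    # gradient, but it no longer owns any memory.
    if data.storage().size() > 0:
        # Resizing a storage shared with other views would corrupt them,
        # so this is only safe for tensors starting at offset 0.
        assert data.storage_offset() == 0
        data.storage().resize_(0)
```

Applied to `empty_grad = torch.empty_like(grad)` below, this yields a zero-byte tensor that satisfies the hook's return contract without duplicating the gradient.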
@@ -218,6 +218,9 @@ class ShardedModelV2(nn.Module):
            else:
                self._reduce_scatter_callback(param, new_grad)
            orig_grad_data.record_stream(self.comm_stream)
        empty_grad = torch.empty_like(grad)
        free_storage(empty_grad)
        return empty_grad

    def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
        if self.gradient_postdivide_factor > 1:
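The added `orig_grad_data.record_stream(self.comm_stream)` follows the standard CUDA caching-allocator rule: a tensor produced on one stream but consumed on another must record the consuming stream, otherwise its memory can be recycled while the asynchronous reduce on the communication stream is still reading it. A self-contained sketch of the pattern (illustrative names, not the ColossalAI hook itself):

```python
import torch
import torch.distributed as dist


def reduce_grad_on_side_stream(grad: torch.Tensor, comm_stream: torch.cuda.Stream) -> None:
    # Make the communication stream wait for the producer of `grad`.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        # Runs asynchronously with respect to the default stream.
        dist.all_reduce(grad)
    # Tell the caching allocator that `grad` is still in use by comm_stream,
    # so its memory is not handed to a new tensor before the reduce finishes.
    grad.record_stream(comm_stream)
```

After the reduce-scatter completes, `_reduce_scatter_callback` divides the result by `gradient_postdivide_factor`; combined with the predivide factor from `get_gradient_predivide_factor`, the total division equals the data-parallel world size, and splitting it across the two sides of the reduction is the usual trick for limiting fp16 overflow and underflow.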