polish code

pull/394/head
jiaruifang 2022-03-09 12:09:07 +08:00 committed by Frank Lee
parent d271f2596b
commit 5663616921
1 changed file with 2 additions and 4 deletions

@@ -18,8 +18,7 @@ from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from ._zero3_utils import (cast_tensor_to_fp32, chunk_and_pad, get_gradient_predivide_factor)
-# from ._zero3_utils import cast_float_arguments, cast_tensor_to_fp16
+from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16)

 class ShardedModelV2(nn.Module):
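
For reference, the newly enabled import pulls two casting helpers out of _zero3_utils. Below is a minimal sketch of what they plausibly look like, inferred only from how they are called in this diff; the real implementations in ColossalAI may also recurse into nested containers:

from typing import Any, Callable, Dict, Tuple

import torch


def cast_tensor_to_fp16(value: Any) -> Any:
    # Sketch: cast only fp32 tensors down to fp16. Integer tensors
    # (e.g. Long token ids) and non-tensor arguments pass through.
    if isinstance(value, torch.Tensor) and value.dtype is torch.float32:
        return value.half()
    return value


def cast_float_arguments(fn: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Tuple[Tuple, Dict]:
    # Sketch: apply the cast to every positional and keyword argument.
    return tuple(fn(a) for a in args), {k: fn(v) for k, v in kwargs.items()}
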
@@ -80,8 +79,7 @@ class ShardedModelV2(nn.Module):
         self._require_backward_grad_sync: bool = True

     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
-        # TODO args can be Long!
-        # args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
+        args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
         return outputs
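
The enabled cast also resolves the removed TODO ("args can be Long!"): because only floating-point tensors are converted, integer arguments reach self.module with their dtype intact. A quick illustration with made-up inputs, using the sketch above:

feats = torch.randn(2, 4)                # fp32 activations
ids = torch.randint(0, 100, (2, 4))      # int64 (Long) token ids

args, kwargs = cast_float_arguments(cast_tensor_to_fp16, feats, ids=ids)
print(args[0].dtype)        # torch.float16
print(kwargs['ids'].dtype)  # torch.int64, left untouched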