From 56636169216f2c49dda0684434dfd95e7667c5e9 Mon Sep 17 00:00:00 2001
From: jiaruifang
Date: Wed, 9 Mar 2022 12:09:07 +0800
Subject: [PATCH] polish code

---
 colossalai/zero/sharded_model/sharded_model_v2.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index ccbec95d8..3531488db 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -18,8 +18,7 @@ from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 
 from ._zero3_utils import (cast_tensor_to_fp32, chunk_and_pad, get_gradient_predivide_factor)
-
-# from ._zero3_utils import cast_float_arguments, cast_tensor_to_fp16
+from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16)
 
 
 class ShardedModelV2(nn.Module):
@@ -80,8 +79,7 @@ class ShardedModelV2(nn.Module):
         self._require_backward_grad_sync: bool = True
 
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
-        # TODO args can be Long!
-        # args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
+        args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
         return outputs
 
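
Note (not part of the patch): the change above re-enables casting of forward
arguments to fp16 via cast_float_arguments / cast_tensor_to_fp16 from
_zero3_utils. Those helpers are not shown in this diff, so the Python below is
only a minimal sketch of the behavior they would need in order to address the
removed "TODO args can be Long!" concern: floating-point tensors are cast to
fp16, while integer (e.g. Long index) tensors pass through untouched. The
_apply_to_tensors helper is a hypothetical name introduced here for
illustration, not the library's actual implementation.

# Sketch only; assumes the semantics described above, not the real _zero3_utils code.
from typing import Any, Callable, Dict, Tuple

import torch


def cast_tensor_to_fp16(tensor: Any) -> Any:
    # Cast only floating-point tensors; Long/int tensors are returned unchanged.
    if isinstance(tensor, torch.Tensor) and tensor.is_floating_point():
        return tensor.half()
    return tensor


def _apply_to_tensors(fn: Callable, obj: Any) -> Any:
    # Recursively walk common containers and apply fn to every tensor found.
    if isinstance(obj, torch.Tensor):
        return fn(obj)
    if isinstance(obj, (list, tuple)):
        return type(obj)(_apply_to_tensors(fn, o) for o in obj)
    if isinstance(obj, dict):
        return {k: _apply_to_tensors(fn, v) for k, v in obj.items()}
    return obj


def cast_float_arguments(fn: Callable, *args: Any, **kwargs: Any) -> Tuple[Tuple, Dict]:
    # Returns the (args, kwargs) pair consumed by ShardedModelV2.forward above.
    return _apply_to_tensors(fn, args), _apply_to_tensors(fn, kwargs)

With this shape of helper, a call such as
args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
halves fp32 inputs for the fp16 sharded forward pass while leaving token-id or
index tensors in their original integer dtype.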