[NFC] polish colossalai/amp/naive_amp/_fp16_optimizer.py code style (#1819)

pull/1849/head
Fazzie-Maqianli 2022-11-08 15:07:02 +08:00 committed by binmakeswell
parent 9623ec1b02
commit 399f84d8f6
1 changed file with 9 additions and 7 deletions

@@ -9,14 +9,16 @@ try:
except:
    print('Colossalai should be built with cuda extension to use the FP16 optimizer')
from torch.optim import Optimizer
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.logging import get_dist_logger
from colossalai.utils import (copy_tensor_parallel_attributes, clip_grad_norm_fp32, multi_tensor_applier)
from torch.distributed import ProcessGroup
from .grad_scaler import BaseGradScaler
from torch.optim import Optimizer
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes, multi_tensor_applier
from ._utils import has_inf_or_nan, zero_gard_by_list
from .grad_scaler import BaseGradScaler
__all__ = ['FP16Optimizer']
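
The hunk above is import-only: the torch imports, the colossalai imports and the relative imports are regrouped and alphabetized, and the parenthesized colossalai.utils import is flattened onto one line. As a minimal sketch (not part of the commit), this kind of regrouping can be reproduced with isort's Python API; that the repository uses isort, and the known_first_party setting below, are assumptions rather than facts shown in this diff:

# Hedged sketch: regroup imports the way the hunk above does, using isort.
# Assumes isort is installed and that colossalai is configured as first-party.
import isort

messy = (
    "from torch.optim import Optimizer\n"
    "from colossalai.core import global_context as gpc\n"
    "from colossalai.context import ParallelMode\n"
    "from torch.distributed import ProcessGroup\n"
    "from .grad_scaler import BaseGradScaler\n"
)

cleaned = isort.code(messy, known_first_party=["colossalai"], line_length=120)
print(cleaned)
# Expected shape: torch imports first, then the colossalai group, then the
# relative import, with each group alphabetized and separated by a blank line.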
@@ -41,7 +43,7 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
class FP16Optimizer(Optimizer):
    """Float16 optimizer for fp16 and bf16 data types.
    Args:
        optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD
        grad_scaler (BaseGradScaler): grad scaler for gradient chose in
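
For context on the class this second hunk documents, below is a hedged usage sketch built only from the optimizer and grad_scaler arguments named in the docstring; the DynamicGradScaler class, its initial_scale keyword and the import paths are assumptions, not details confirmed by this diff:

# Hedged sketch: wrap a base optimizer in FP16Optimizer as the docstring
# describes. Import paths, DynamicGradScaler and its keyword arguments are
# assumptions; only the optimizer/grad_scaler pairing comes from the docstring.
import torch

from colossalai.amp.naive_amp._fp16_optimizer import FP16Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler

model = torch.nn.Linear(32, 32).cuda().half()                    # fp16 parameters
base_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)   # "Adam or SGD"

grad_scaler = DynamicGradScaler(initial_scale=2**16)             # a BaseGradScaler subclass
fp16_optimizer = FP16Optimizer(base_optimizer, grad_scaler=grad_scaler)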