[NFC] polish colossalai/auto_parallel/offload/amp_optimizer.py code style (#4255)

pull/4338/head
Yanjia0 2023-07-18 10:54:55 +08:00 committed by binmakeswell
parent 85774f0c1f
commit c614a99d28
1 changed files with 6 additions and 5 deletions

View File

@ -1,24 +1,25 @@
from typing import Dict, Tuple
from enum import Enum
from typing import Dict, Tuple
import torch
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.utils import get_current_device
from .base_offload_module import BaseOffloadModule
from .region_manager import RegionManager
from .region import Region
from .region_manager import RegionManager
class OptimState(Enum):
SCALED = 0
UNSCALED = 1
class AMPOptimizer(ColossalaiOptimizer):
class AMPOptimizer(ColossalaiOptimizer):
"""
A wrapper for Optimizer.
Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
@ -174,4 +175,4 @@ class AMPOptimizer(ColossalaiOptimizer):
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
self.optim.load_state_dict(self.optim.state_dict())
self.optim.load_state_dict(self.optim.state_dict())