mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/auto_parallel/offload/amp_optimizer.py code style (#4255)
parent 85774f0c1f
commit c614a99d28
colossalai/auto_parallel/offload/amp_optimizer.py
@@ -1,24 +1,25 @@
-from typing import Dict, Tuple
 from enum import Enum
+from typing import Dict, Tuple
 
 import torch
 from torch.optim import Optimizer
 
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.utils import get_current_device
 
 from .base_offload_module import BaseOffloadModule
-from .region_manager import RegionManager
 from .region import Region
+from .region_manager import RegionManager
 
 
 class OptimState(Enum):
     SCALED = 0
     UNSCALED = 1
 
 
-class AMPOptimizer(ColossalaiOptimizer):
+
+class AMPOptimizer(ColossalaiOptimizer):
     """
     A wrapper for Optimizer.
     Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
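The reordering in this hunk matches isort's default convention: imports grouped as standard library, third-party, first-party, and local, alphabetized within each group, which is presumably how this NFC polish was generated. A minimal sketch of the effect, assuming isort 5's isort.code API:

import isort

# Two stdlib imports in the pre-commit order; isort.code() returns the
# source with imports sorted, matching the first hunk above.
messy = "from typing import Dict, Tuple\nfrom enum import Enum\n"
print(isort.code(messy), end="")
# Output:
# from enum import Enum
# from typing import Dict, Tuple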
@@ -174,4 +175,4 @@ class AMPOptimizer(ColossalaiOptimizer):
 
         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
-        self.optim.load_state_dict(self.optim.state_dict())
+        self.optim.load_state_dict(self.optim.state_dict())
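For context, the state_dict()/load_state_dict() round-trip in the second hunk is a standard recasting trick: torch.optim.Optimizer.load_state_dict casts floating-point state tensors to the dtype and device of the parameter they belong to, so feeding an optimizer its own state dict realigns preexisting per-param state (e.g. Adam's exp_avg) after the parameters themselves have been converted. A minimal standalone sketch of the effect, independent of this file:

import torch

# One fp32 parameter with Adam state materialized by a single step.
param = torch.nn.Parameter(torch.zeros(4))
optim = torch.optim.Adam([param], lr=1e-3)
param.grad = torch.zeros_like(param)
optim.step()

# Convert the parameter to fp16, then reload the optimizer's own state
# dict: load_state_dict casts floating-point state to the param's dtype.
param.data = param.data.half()
optim.load_state_dict(optim.state_dict())
print(optim.state[param]["exp_avg"].dtype)  # torch.float16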