diff --git a/colossalai/gemini/tensor_placement_policy.py b/colossalai/gemini/tensor_placement_policy.py index cfcfb3856..0e575254c 100644 --- a/colossalai/gemini/tensor_placement_policy.py +++ b/colossalai/gemini/tensor_placement_policy.py @@ -1,16 +1,16 @@ +import functools from abc import ABC, abstractmethod from time import time -from typing import List, Optional +from typing import List, Optional, Type + import torch + +from colossalai.gemini.memory_tracer import MemStatsCollector +from colossalai.gemini.stateful_tensor import StatefulTensor +from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage from colossalai.utils import get_current_device from colossalai.utils.memory import colo_device_memory_capacity -from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage -from colossalai.gemini.stateful_tensor import StatefulTensor -from colossalai.gemini.memory_tracer import MemStatsCollector -from typing import Type -import functools - class TensorPlacementPolicy(ABC):