diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/graph_analysis.py b/colossalai/auto_parallel/tensor_shard/deprecated/graph_analysis.py index 831e7eadd..9f7a6a5ec 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/graph_analysis.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/graph_analysis.py @@ -1,9 +1,11 @@ +from collections import OrderedDict as ODict from dataclasses import dataclass -from torch.fx.node import Node +from typing import Any, List, OrderedDict, Union + from torch.fx.graph import Graph from torch.fx.graph_module import GraphModule -from collections import OrderedDict as ODict -from typing import List, OrderedDict, Union, Any +from torch.fx.node import Node + from colossalai.fx.passes.utils import get_node_module __all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']