diff --git a/colossalai/fx/passes/split_module.py b/colossalai/fx/passes/split_module.py index bc257edc8..9bc4bf1f5 100644 --- a/colossalai/fx/passes/split_module.py +++ b/colossalai/fx/passes/split_module.py @@ -1,9 +1,10 @@ -import torch -from torch.fx.graph_module import GraphModule -from typing import Callable, List, Dict, Any, Optional -from torch.fx._compatibility import compatibility -from packaging import version import inspect +from typing import Any, Callable, Dict, List, Optional + +import torch +from packaging import version +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule @compatibility(is_backward_compatible=True) @@ -38,7 +39,7 @@ def split_module( m: GraphModule, root_m: torch.nn.Module, split_callback: Callable[[torch.fx.node.Node], int], - merge_output = False, + merge_output=False, ): """ Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py @@ -132,10 +133,8 @@ def split_module( use_partition.inputs.setdefault(def_node.name) if def_partition_name is not None: use_partition.partitions_dependent_on.setdefault(def_partition_name) - - def record_output( - def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node] - ): # noqa: B950 + + def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950 def_partition_name = getattr(def_node, "_fx_partition", None) use_partition_name = getattr(use_node, "_fx_partition", None) if def_partition_name != use_partition_name: @@ -291,7 +290,7 @@ def split_module( for partition_name in sorted_partitions: partition = partitions[partition_name] - + new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) return new_gm