[NFC] polish colossalai/fx/passes/split_module.py code style (#3263)

Co-authored-by: csric <richcsr256@gmail.com>
pull/3313/head
CsRic 2023-03-27 22:03:29 +08:00 committed by binmakeswell
parent 488f37048c
commit 00778abc48
1 changed files with 10 additions and 11 deletions

View File

@ -1,9 +1,10 @@
import torch
from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility
from packaging import version
import inspect
from typing import Any, Callable, Dict, List, Optional
import torch
from packaging import version
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
@compatibility(is_backward_compatible=True)
@ -38,7 +39,7 @@ def split_module(
m: GraphModule,
root_m: torch.nn.Module,
split_callback: Callable[[torch.fx.node.Node], int],
merge_output = False,
merge_output=False,
):
"""
Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
@ -132,10 +133,8 @@ def split_module(
use_partition.inputs.setdefault(def_node.name)
if def_partition_name is not None:
use_partition.partitions_dependent_on.setdefault(def_partition_name)
def record_output(
def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]
): # noqa: B950
def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, "_fx_partition", None)
use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name:
@ -291,7 +290,7 @@ def split_module(
for partition_name in sorted_partitions:
partition = partitions[partition_name]
new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
return new_gm