From e81caeb4bc20ed14be0dd5f52d14c0f11813c817 Mon Sep 17 00:00:00 2001 From: Xue Fuzhao <57164838+XueFuzhao@users.noreply.github.com> Date: Wed, 15 Feb 2023 16:12:45 +0800 Subject: [PATCH] [NFC] polish colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py code style (#2720) Co-authored-by: Fuzhao Xue --- .../tensor_shard/deprecated/cost_graph.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py b/colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py index 239d02115..50220bca6 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py @@ -1,6 +1,8 @@ -from typing import List import math +from typing import List + from torch.fx.node import Node + from .constants import INFINITY_COST @@ -9,7 +11,7 @@ class CostGraph: A graph data structure to simplify the edge cost graph. It has two main functions: 1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list. - 2. To reduce the searching space, we merge computationally-trivial operators, such as + 2. To reduce the searching space, we merge computationally-trivial operators, such as element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will be given by the StrategiesVector depending on the type of target node and following nodes. @@ -75,14 +77,14 @@ class CostGraph: def merge_node(self, src_node, dst_node): ''' To merge dst_node into src_node, we need to do it in following steps: - + 1. For each strategy in dst_node, we need to pick an appropriate strategy - of src_node to merge, it is important because the logical resharding costs - between the parents node of src_node and merged node depend on the src_node + of src_node to merge, it is important because the logical resharding costs + between the parents node of src_node and merged node depend on the src_node strategies dispatching. For example, for the graph 0->1->2, after merging node 1 into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)] x represents the picking strategy of node 1 merged into node 2 strategy 0. - + 2. We need to accumulate the extra costs introduced by merging nodes, the extra costs contains two parts, one is resharding costs between src_node strategy and dst_node strategy, another is the origin extra costs in src_node strategy.