Revert "[NFC] polish code format" (#2372)

pull/2375/head^2
binmakeswell 2023-01-06 16:01:09 +08:00 committed by GitHub
parent 0dcc410f57
commit a881d6d000
6 changed files with 27 additions and 30 deletions

View File

@@ -1,6 +1,5 @@
-import operator
 import torch
+import operator
 __all__ = [
     'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',

View File

@@ -1,11 +1,9 @@
-from collections import OrderedDict as ODict
 from dataclasses import dataclass
-from typing import Any, List, OrderedDict, Union
+from torch.fx.node import Node
 from torch.fx.graph import Graph
 from torch.fx.graph_module import GraphModule
-from torch.fx.node import Node
+from collections import OrderedDict as ODict
+from typing import List, OrderedDict, Union, Any
 from colossalai.fx.passes.utils import get_node_module
 __all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']

View File

@@ -2,9 +2,9 @@ import operator
 from functools import reduce
 import torch
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
+    ignore_sharding_exception
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
 from .operator_handler import OperatorHandler
@@ -76,19 +76,19 @@ class BatchNormHandler(OperatorHandler):
        Argument:
            sharding_size_forward(int): The forward activation will be divided
                into sharding_size_forward number partions.
            sharding_size_backward_activation(int): The backward activation will
                be divided into sharding_size_backward_activation number partions.
            sharding_size_weight(int): The backward weight will be divided
                into sharding_size_weight number partions.

        Return:
            memory_cost(Tuple[float]): Memory cost per device with this
                specific strategy, the first element of this tuple is forward
                memory cost, and the second element of this tuple is backward
                memory cost.
            memory_cost_forward(float): Memory cost of forward activation per
                device with this specific strategy.
            memory_cost_backward_activation(float): Memory cost of backward activation
                per device with this specific strategy.
        '''
        # compute the memory cost of this strategy
@@ -458,7 +458,7 @@ class BatchNormHandler(OperatorHandler):
        norm_handler.register_strategy()
        for strategy in norm_handler.strategies_vector:
            print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')

    Output:
        RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
        RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
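The memory-cost docstring in the hunk above maps partition counts to per-device costs. The following is a minimal, hypothetical sketch of that arithmetic; the function name, the fixed element size, and the exact return layout are assumptions for illustration, not the handler's actual code.

```python
# Hypothetical sketch only -- not the BatchNormHandler implementation.
def sketch_memory_cost(activation_numel: int,
                       weight_numel: int,
                       sharding_size_forward: int,
                       sharding_size_backward_activation: int,
                       sharding_size_weight: int,
                       bytes_per_elem: int = 4) -> tuple:
    # Each sharding size divides the corresponding tensor across that many devices.
    forward_cost = activation_numel / sharding_size_forward * bytes_per_elem
    backward_activation_cost = activation_numel / sharding_size_backward_activation * bytes_per_elem
    backward_weight_cost = weight_numel / sharding_size_weight * bytes_per_elem
    # memory_cost tuple per the docstring: (forward cost, backward cost) per device.
    return (forward_cost, backward_activation_cost + backward_weight_cost)
```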

View File

@@ -5,9 +5,9 @@ from functools import reduce
 from typing import Dict, List
 import torch
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
+    ignore_sharding_exception
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
@@ -42,19 +42,19 @@ class EmbeddingHandler(OperatorHandler):
        Argument:
            sharding_size_forward(int): The forward activation will be divided
                into sharding_size_forward number partions.
            sharding_size_backward_activation(int): The backward activation will
                be divided into sharding_size_backward_activation number partions.
            sharding_size_weight(int): The backward weight will be divided
                into sharding_size_weight number partions.

        Return:
            memory_cost(Tuple[float]): Memory cost per device with this
                specific strategy, the first element of this tuple is forward
                memory cost, and the second element of this tuple is backward
                memory cost.
            memory_cost_forward(float): Memory cost of forward activation per
                device with this specific strategy.
            memory_cost_backward_activation(float): Memory cost of backward activation
                per device with this specific strategy.
        '''
        # compute the memory cost of this strategy

View File

@@ -6,10 +6,11 @@ from functools import reduce
 from typing import Dict, List
 import torch
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
-from colossalai.auto_parallel.tensor_shard.deprecated.constants import INFINITY_COST
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
+    ignore_sharding_exception
+from colossalai.auto_parallel.tensor_shard.deprecated.constants import \
+    INFINITY_COST
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec

View File

@@ -4,11 +4,10 @@ from functools import reduce
 from typing import Dict, List, Optional, Union
 import torch
-from torch.fx.node import Node
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
+from torch.fx.node import Node
 from ..constants import INFINITY_COST
@@ -19,7 +18,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
                           dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
    """
    Generate the sharding spec of the tensor based on the given dim_partition_dict.

    Args:
        input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is used, it will look for its meta data associated with this node.
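For readers unfamiliar with the dim_partition_dict argument documented above: it maps tensor dimensions to the device-mesh axes they are sharded along. The values below are purely illustrative (a 2D tensor on a 2D mesh is assumed), not taken from this diff.

```python
# Hypothetical examples of dim_partition_dict values; a dimension with no key
# is replicated rather than sharded.
shard_dim0_on_mesh_axis0 = {0: [0]}    # shard tensor dim 0 across mesh axis 0
shard_dim0_on_both_axes = {0: [0, 1]}  # shard tensor dim 0 across mesh axes 0 and 1
fully_replicated: dict = {}            # no sharding on any dimension
```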
@@ -60,7 +59,7 @@ def generate_resharding_costs(nodes: List[Node],
        nodes (List[Node]): a list of nodes
        sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes.
        count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
        dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
    '''
    # The resharding_cost of weight is counted due to sharing weight cases.
    resharding_costs = {}