[NFC] polish colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py code style (#2305)

pull/2317/head
Ofey Chan 2 years ago committed by Frank Lee
parent 116e3d0b8f
commit 87d2defda6

@@ -2,10 +2,14 @@ import operator
 from functools import reduce
 import torch
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
-                                                                      enumerate_all_possible_2d_sharding,
-                                                                      generate_sharding_size, ignore_sharding_exception)
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
+    enumerate_all_possible_1d_sharding,
+    enumerate_all_possible_2d_sharding,
+    generate_sharding_size,
+    ignore_sharding_exception,
+)
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
 from .operator_handler import OperatorHandler
@@ -63,19 +67,19 @@ class LayerNormHandler(OperatorHandler):
         Argument:
             sharding_size_forward(int): The forward activation will be divided
                 into sharding_size_forward number partions.
             sharding_size_backward_activation(int): The backward activation will
                 be divided into sharding_size_backward_activation number partions.
             sharding_size_weight(int): The backward weight will be divided
                 into sharding_size_weight number partions.
         Return:
             memory_cost(Tuple[float]): Memory cost per device with this
                 specific strategy, the first element of this tuple is forward
                 memory cost, and the second element of this tuple is backward
                 memory cost.
             memory_cost_forward(float): Memory cost of forward activation per
                 device with this specific strategy.
             memory_cost_backward_activation(float): Memory cost of backward activation
                 per device with this specific strategy.
         '''
         # compute the memory cost of this strategy
@@ -216,7 +220,7 @@ class LayerNormHandler(OperatorHandler):
         norm_handler.register_strategy()
         for strategy in norm_handler.strategies_vector:
             print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
     Output:
         RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
         RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
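
For reference, a minimal sketch of the formatting convention this NFC commit applies: when a parenthesized import no longer fits on one line, the opening parenthesis ends the line and each imported name sits on its own line with a trailing comma, while short imports drop the parentheses entirely. The module and names below are standard-library stand-ins chosen only for illustration; they are not taken from the colossalai diff.

# Hypothetical illustration of the import layout shown in the hunk above;
# os.path and its names are stand-ins, not part of the commit.
from os.path import (
    abspath,
    basename,
    dirname,
    join,
)

# Short imports stay on a single line without parentheses.
print(join(dirname(abspath(__file__)), basename(__file__)))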
