[NFC] polish batch_norm_handler.py code style (#2359)

pull/2367/head
ExtremeViscent 2023-01-06 13:41:38 +08:00 committed by GitHub
parent e11a005c02
commit ac0d30fe2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 8 deletions

View File

@ -2,9 +2,9 @@ import operator
from functools import reduce from functools import reduce
import torch import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector) from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler from .operator_handler import OperatorHandler
@ -76,19 +76,19 @@ class BatchNormHandler(OperatorHandler):
Argument: Argument:
sharding_size_forward(int): The forward activation will be divided sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward number partions. into sharding_size_forward number partions.
sharding_size_backward_activation(int): The backward activation will sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation number partions. be divided into sharding_size_backward_activation number partions.
sharding_size_weight(int): The backward weight will be divided sharding_size_weight(int): The backward weight will be divided
into sharding_size_weight number partions. into sharding_size_weight number partions.
Return: Return:
memory_cost(Tuple[float]): Memory cost per device with this memory_cost(Tuple[float]): Memory cost per device with this
specific strategy, the first element of this tuple is forward specific strategy, the first element of this tuple is forward
memory cost, and the second element of this tuple is backward memory cost, and the second element of this tuple is backward
memory cost. memory cost.
memory_cost_forward(float): Memory cost of forward activation per memory_cost_forward(float): Memory cost of forward activation per
device with this specific strategy. device with this specific strategy.
memory_cost_backward_activation(float): Memory cost of backward activation memory_cost_backward_activation(float): Memory cost of backward activation
per device with this specific strategy. per device with this specific strategy.
''' '''
# compute the memory cost of this strategy # compute the memory cost of this strategy
@ -458,7 +458,7 @@ class BatchNormHandler(OperatorHandler):
norm_handler.register_strategy() norm_handler.register_strategy()
for strategy in norm_handler.strategies_vector: for strategy in norm_handler.strategies_vector:
print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}') print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
Output: Output:
RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0 RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0 RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0