[NFC] polish colossalai/builder/pipeline.py code style (#638)

pull/673/head
Ziheng Qin 2022-04-02 14:22:41 +08:00 committed by binmakeswell
parent 10591ecdf9
commit c7c224ee17
1 changed files with 6 additions and 3 deletions

View File

@ -1,7 +1,6 @@
import copy import copy
import heapq import heapq
from colossalai.builder import build_model, build_layer from colossalai.builder import build_model, build_layer
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
@ -40,6 +39,7 @@ def _binary_partition(weights, st, ed):
def _heap_addition(weights, intervals, add_cnt): def _heap_addition(weights, intervals, add_cnt):
""" """
""" """
def _heap_push(heap, st, ed): def _heap_push(heap, st, ed):
value = weights[ed - 1] value = weights[ed - 1]
if st > 0: if st > 0:
@ -162,7 +162,10 @@ def count_layer_params(layers):
return param_counts return param_counts
def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method: str = 'parameter', verbose: bool = False): def build_pipeline_model_from_cfg(config,
num_chunks: int = 1,
partition_method: str = 'parameter',
verbose: bool = False):
"""An initializer to split the model into different stages for pipeline parallelism. """An initializer to split the model into different stages for pipeline parallelism.
An example for the model config is shown below. The class VisionTransformerFromConfig should An example for the model config is shown below. The class VisionTransformerFromConfig should
@ -218,7 +221,7 @@ def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method:
log_str += f'\n===== stage={stage}, layers={num_layers} =====\n' log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
for st, ed in parts[stage]: for st, ed in parts[stage]:
for idx, layer in enumerate(layers[st: ed]): for idx, layer in enumerate(layers[st:ed]):
log_str += f'\t{idx + st:2d}: {layer}\n' log_str += f'\t{idx + st:2d}: {layer}\n'
logger.info(log_str, ranks=[0]) logger.info(log_str, ranks=[0])