Browse Source

[NFC] polish colossalai/builder/pipeline.py code style (#638)

pull/673/head
Ziheng Qin 3 years ago committed by binmakeswell
parent
commit
c7c224ee17
  1. 9
      colossalai/builder/pipeline.py

9
colossalai/builder/pipeline.py

@ -1,7 +1,6 @@
import copy import copy
import heapq import heapq
from colossalai.builder import build_model, build_layer from colossalai.builder import build_model, build_layer
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
@ -40,6 +39,7 @@ def _binary_partition(weights, st, ed):
def _heap_addition(weights, intervals, add_cnt): def _heap_addition(weights, intervals, add_cnt):
""" """
""" """
def _heap_push(heap, st, ed): def _heap_push(heap, st, ed):
value = weights[ed - 1] value = weights[ed - 1]
if st > 0: if st > 0:
@ -162,7 +162,10 @@ def count_layer_params(layers):
return param_counts return param_counts
def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method: str = 'parameter', verbose: bool = False): def build_pipeline_model_from_cfg(config,
num_chunks: int = 1,
partition_method: str = 'parameter',
verbose: bool = False):
"""An initializer to split the model into different stages for pipeline parallelism. """An initializer to split the model into different stages for pipeline parallelism.
An example for the model config is shown below. The class VisionTransformerFromConfig should An example for the model config is shown below. The class VisionTransformerFromConfig should
@ -218,7 +221,7 @@ def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method:
log_str += f'\n===== stage={stage}, layers={num_layers} =====\n' log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
for st, ed in parts[stage]: for st, ed in parts[stage]:
for idx, layer in enumerate(layers[st: ed]): for idx, layer in enumerate(layers[st:ed]):
log_str += f'\t{idx + st:2d}: {layer}\n' log_str += f'\t{idx + st:2d}: {layer}\n'
logger.info(log_str, ranks=[0]) logger.info(log_str, ranks=[0])

Loading…
Cancel
Save