|
|
|
@ -251,9 +251,9 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
|
|
|
|
|
partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks) |
|
|
|
|
module_list = [] |
|
|
|
|
for start, end in partitions[pipeline_rank]: |
|
|
|
|
module_list.append(nn.Sequential(*[nn.Identity() for _ in range(start)], |
|
|
|
|
*layers[start:end], |
|
|
|
|
*[nn.Identity() for _ in range(len(layers) - end)])) |
|
|
|
|
module_list.append( |
|
|
|
|
nn.Sequential(*[nn.Identity() for _ in range(start)], *layers[start:end], |
|
|
|
|
*[nn.Identity() for _ in range(len(layers) - end)])) |
|
|
|
|
if verbose: |
|
|
|
|
logger = get_dist_logger() |
|
|
|
|
logger.info(f'Total {len(layers)} layers', ranks=[0]) |
|
|
|
@ -264,4 +264,3 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
|
|
|
|
|
log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n' |
|
|
|
|
logger.info(log_str, ranks=[0]) |
|
|
|
|
return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0] |
|
|
|
|
|