diff --git a/colossalai/builder/pipeline.py b/colossalai/builder/pipeline.py
index a25e8990d..3d14ce23e 100644
--- a/colossalai/builder/pipeline.py
+++ b/colossalai/builder/pipeline.py
@@ -240,7 +240,6 @@ def build_pipeline_model_from_cfg(config,
 def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bool = False):
     """An intializer to split the model into different stages for pipeline parallelism.
     Note that `layer` must be `torch.nn.Sequential`.
-
     Args:
         layers (`torch.nn.Sequential`): Layers of model
         num_chunks: The number of chunks you want to have on the current stage. This value should be 1
@@ -252,7 +251,9 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
     partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
     module_list = []
     for start, end in partitions[pipeline_rank]:
-        module_list.append(nn.Sequential(*layers[start:end]))
+        module_list.append(nn.Sequential(*[nn.Identity() for _ in range(start)],
+                                         *layers[start:end],
+                                         *[nn.Identity() for _ in range(len(layers) - end)]))
     if verbose:
         logger = get_dist_logger()
         logger.info(f'Total {len(layers)} layers', ranks=[0])
@@ -263,3 +264,4 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
             log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
             logger.info(log_str, ranks=[0])
     return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]
+
\ No newline at end of file
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 08bd43f62..bdbc96681 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -399,6 +399,8 @@ def initialize(model: nn.Module,
         else:
             scatter_gather = False
         if use_interleaved:
+            if isinstance(model, nn.Sequential):
+                model = nn.ModuleList([model])
             schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
                                                    gpc.config.model.num_chunks,
                                                    tensor_shape=tensor_shape,
                                                    scatter_gather_tensors=scatter_gather)
@@ -434,7 +436,6 @@ def initialize(model: nn.Module,
                                                      accumulate_size=grad_accum_size,
                                                      gradient_handlers=gradient_handlers,
                                                      lr_scheduler=lr_scheduler)
-
     engine = Engine(model=model,
                     optimizer=optimizer,
                     criterion=criterion,
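For illustration, here is a minimal, self-contained sketch of what the `build_pipeline_model` change does: each stage's slice of the model is padded with `nn.Identity` placeholders so that every stage's `nn.Sequential` keeps the same length and layer indices as the full model. The helper name `pad_partition` is hypothetical and not part of the patch; it only mirrors the new `nn.Sequential(...)` construction.

```python
import torch
import torch.nn as nn

def pad_partition(layers: nn.Sequential, start: int, end: int) -> nn.Sequential:
    # Hypothetical helper (not in the patch): keep indices [0, len(layers))
    # valid on every stage by padding both sides with no-op nn.Identity modules.
    return nn.Sequential(*[nn.Identity() for _ in range(start)],
                         *layers[start:end],
                         *[nn.Identity() for _ in range(len(layers) - end)])

model = nn.Sequential(*[nn.Linear(8, 8) for _ in range(6)])
stage = pad_partition(model, 2, 4)   # this stage owns layers 2 and 3
assert len(stage) == len(model)      # layer indices stay aligned across stages
out = stage(torch.randn(1, 8))       # Identity modules pass activations through
print(out.shape)                     # torch.Size([1, 8])
```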
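The `initialize.py` hunk adds a guard before building the `InterleavedPipelineSchedule`. A plausible reading (an assumption, not stated in the diff) is that the interleaved schedule treats the model as a list of chunks, so a bare `nn.Sequential` is wrapped as a one-chunk `nn.ModuleList`. The snippet below only mirrors that added guard:

```python
import torch.nn as nn

# Mirrors the guard added in initialize(); the rationale above is assumed.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
if isinstance(model, nn.Sequential):
    model = nn.ModuleList([model])

assert isinstance(model, nn.ModuleList) and len(model) == 1
```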