mirror of https://github.com/hpcaitech/ColossalAI
modified the pp build for ckpt adaptation (#803)
parent 8789850eea
commit c1e8d2001e
@@ -240,7 +240,6 @@ def build_pipeline_model_from_cfg(config,
 def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bool = False):
     """An initializer to split the model into different stages for pipeline parallelism.
     Note that `layers` must be `torch.nn.Sequential`.
-
     Args:
         layers (`torch.nn.Sequential`): Layers of model
         num_chunks: The number of chunks you want to have on the current stage. This value should be 1
@@ -252,7 +251,9 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bool = False):
     partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
     module_list = []
     for start, end in partitions[pipeline_rank]:
-        module_list.append(nn.Sequential(*layers[start:end]))
+        module_list.append(nn.Sequential(*[nn.Identity() for _ in range(start)],
+                                         *layers[start:end],
+                                         *[nn.Identity() for _ in range(len(layers) - end)]))
     if verbose:
         logger = get_dist_logger()
         logger.info(f'Total {len(layers)} layers', ranks=[0])
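The substance of the commit is the padding above: every pipeline stage now keeps one submodule slot per layer of the full model, with nn.Identity placeholders standing in for the layers owned by other ranks. A stage's state_dict keys therefore carry the same global layer indices as the full model, which is what makes checkpoint adaptation possible. A minimal sketch of the effect (the toy model and partition bounds are illustrative assumptions, not part of the commit):

import torch.nn as nn

layers = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU())
start, end = 2, 4    # hypothetical partition owned by the current pipeline rank

# Same construction as the patched line above.
stage = nn.Sequential(*[nn.Identity() for _ in range(start)],
                      *layers[start:end],
                      *[nn.Identity() for _ in range(len(layers) - end)])

# The stage's parameter keys ("2.weight", "2.bias") are a subset of the full
# model's keys, because the nn.Identity placeholders preserve the indices.
assert set(stage.state_dict()) <= set(layers.state_dict())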
@@ -263,3 +264,4 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bool = False):
             log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
         logger.info(log_str, ranks=[0])
     return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]
+
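For the checkpoint-adaptation use case this enables, a hedged sketch (loading with strict=False is an assumption about how a full-model checkpoint would be applied per stage; it is not shown in this commit):

import torch.nn as nn

full = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
ckpt = full.state_dict()    # checkpoint keyed by global layer index

# A stage that owns only layer 2, padded as in the patch above.
stage = nn.Sequential(nn.Identity(), nn.Identity(), nn.Linear(8, 8))

# strict=False skips the keys owned by other pipeline ranks.
missing, unexpected = stage.load_state_dict(ckpt, strict=False)
assert not missing    # every local parameter was filled
print(unexpected)     # ['0.weight', '0.bias'] -- held by other ranks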
@@ -399,6 +399,8 @@ def initialize(model: nn.Module,
     else:
         scatter_gather = False
     if use_interleaved:
+        if isinstance(model, nn.Sequential):
+            model = nn.ModuleList([model])
         schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES,
                                                gpc.config.model.num_chunks,
                                                tensor_shape=tensor_shape,
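The second change, in initialize(): InterleavedPipelineSchedule consumes the model as a list of chunks, so a bare nn.Sequential (e.g. the single-chunk module that build_pipeline_model returns) is first wrapped in an nn.ModuleList. A small sketch of that normalization (the helper name is hypothetical, not part of the commit):

import torch.nn as nn

def as_chunk_list(model: nn.Module) -> nn.Module:
    # Hypothetical helper mirroring the new branch in initialize().
    if isinstance(model, nn.Sequential):
        return nn.ModuleList([model])
    return model

chunks = as_chunk_list(nn.Sequential(nn.Linear(4, 4)))
assert isinstance(chunks, nn.ModuleList) and len(chunks) == 1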
@@ -434,7 +436,6 @@ def initialize(model: nn.Module,
                                          accumulate_size=grad_accum_size,
                                          gradient_handlers=gradient_handlers,
                                          lr_scheduler=lr_scheduler)
-
     engine = Engine(model=model,
                     optimizer=optimizer,
                     criterion=criterion,