[shardformer] support DDP in HybridPlugin/add tp+dp tests (#4446)

* support DDP for HybridPlugin/add tp+dp tests

* add docstring for HybridParallelPlugin
Baizhou Zhang 2023-08-16 16:11:57 +08:00 committed by GitHub
parent 424629fea0
commit 6ef33f75aa
10 changed files with 199 additions and 100 deletions

View File

@@ -6,7 +6,8 @@ import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.nn import Module, SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
@@ -28,7 +29,8 @@ DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
class HybridParallelModule(ModelWrapper):
def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup) -> None:
def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
ddp_config: dict) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.dp_group = dp_group
shardformer = ShardFormer(shard_config)
@@ -45,7 +47,15 @@ class HybridParallelModule(ModelWrapper):
module = module.to(dtype=torch.bfloat16).cuda()
else:
module = module.cuda() # train without AMP
# TODO(ver217): support TP+DP
if use_ddp:
# convert model to sync bn
module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group)
# wrap the model with PyTorch DDP
module = DDP(module, process_group=dp_group, **ddp_config)
super().__init__(module)
def sync_shared_params(self):
@@ -68,6 +78,12 @@ class HybridParallelModule(ModelWrapper):
dist.all_reduce(p.grad, group=self.dp_group)
p.grad.div_(self.dp_group.size())
def unwrap(self):
module = super().unwrap()
if isinstance(module, DDP):
module = module.module
return module
def init_pipeline_optimizer(optim: Optimizer, model: Module):
params = set(model.parameters())
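The wrapping added in the hunks above follows the stock PyTorch recipe, scoped to the data-parallel group: batch-norm layers are converted to SyncBatchNorm over dp_group, then the sharded module is wrapped in DistributedDataParallel with that same group, so gradient all-reduce never crosses tensor-parallel ranks. A minimal standalone sketch of that pattern, assuming a torchrun-style launch; the toy model, the world-sized dp_group and the wrap_for_data_parallel helper are illustrative, not part of the plugin:

```python
import torch
import torch.distributed as dist
from torch.nn import SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_for_data_parallel(module: torch.nn.Module, dp_group, ddp_config: dict) -> DDP:
    # Replace every BatchNorm*d with SyncBatchNorm so batch statistics are
    # synchronized across the data-parallel ranks only.
    module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group)
    # Restrict gradient all-reduce to the same data-parallel group, leaving
    # tensor-parallel communication untouched.
    return DDP(module, process_group=dp_group, **ddp_config)


if __name__ == '__main__':
    dist.init_process_group(backend='nccl')    # assumes a torchrun-style launch
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    dp_group = dist.group.WORLD                # toy case: every rank is data-parallel
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8)).cuda()
    ddp_config = dict(broadcast_buffers=True, bucket_cap_mb=25, find_unused_parameters=False)
    model = wrap_for_data_parallel(model, dp_group, ddp_config)
```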
@@ -140,29 +156,81 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
class HybridParallelPlugin(PipelinePluginBase):
"""
Plugin for Hybrid Parallel Training.
Tensor parallel, pipeline parallel and data parallel (DDP/ZeRO) can be picked and combined in this plugin.
The sizes of tp and pp should be passed in by the user; the size of dp is then calculated automatically as dp_size = world_size / (tp_size * pp_size).
Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin import HybridParallelPlugin
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
>>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
>>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
Args:
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
precision (str, optional): Specifies the precision of parameters during training.
Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16'; otherwise the model is trained in 'fp32'.
Defaults to 'fp16'.
zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
When set to 0, ZeRO will not be used. Defaults to 0.
cpu_offload (bool, optional): Whether to enable cpu_offload when using ZeRO. Defaults to False.
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
Currently the optimization methods include fused normalization, flash attention and JIT fusion.
Defaults to False.
enable_fused_normalization (bool, optional): Whether to switch on fused normalization. Defaults to False.
enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT fusion. Defaults to False.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
broadcast_buffers (bool, optional): Whether to broadcast buffers at the beginning of training. Only for usage of DDP. Defaults to True.
bucket_cap_mb (int, optional): The bucket size in MB. Only for usage of DDP. Defaults to 25.
find_unused_parameters (bool, optional): Whether to find unused parameters. Only for usage of DDP. Defaults to False.
check_reduction (bool, optional): Whether to check reduction. Only for usage of DDP. Defaults to False.
gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Only for usage of DDP. Defaults to False.
static_graph (bool, optional): Whether to use static graph. Only for usage of DDP. Defaults to False.
"""
def __init__(
self,
tp_size: int,
pp_size: int,
precision: str = 'fp16',
zero_stage: int = 0,
cpu_offload: bool = False,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
num_microbatches: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
) -> None:
def __init__(self,
tp_size: int,
pp_size: int,
precision: str = 'fp16',
zero_stage: int = 0,
cpu_offload: bool = False,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
num_microbatches: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
broadcast_buffers=True,
bucket_cap_mb=25,
find_unused_parameters=False,
check_reduction=False,
gradient_as_bucket_view=False,
static_graph=False) -> None:
super().__init__()
assert dist.get_world_size() % (
tp_size * pp_size
@@ -208,6 +276,13 @@ class HybridParallelPlugin(PipelinePluginBase):
min_scale=min_scale,
max_scale=max_scale,
)
self.ddp_config = dict(broadcast_buffers=broadcast_buffers,
bucket_cap_mb=bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph)
self.max_norm = max_norm
@property
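The DDP-related keyword arguments documented above are simply collected into self.ddp_config and forwarded to torch's DistributedDataParallel when the plugin wraps the model. A hedged usage sketch of the extended constructor, mirroring the docstring example; build_model_and_data is a hypothetical placeholder, not part of ColossalAI:

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Hypothetical placeholder: any torch model / dataset / optimizer / loss works here.
model, train_dataset, optimizer, criterion = build_model_and_data()

plugin = HybridParallelPlugin(
    tp_size=2,                    # tensor parallelism across 2 ranks
    pp_size=1,                    # no pipeline parallelism
    zero_stage=0,                 # ZeRO off, so plain DDP handles data parallelism
    precision='fp32',
    broadcast_buffers=True,       # collected into ddp_config and passed to DDP
    bucket_cap_mb=25,
    find_unused_parameters=False,
    gradient_as_bucket_view=False,
    static_graph=False,
)
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, criterion, train_dataloader, _ = booster.boost(
    model, optimizer, criterion, train_dataloader)
```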
@@ -241,7 +316,9 @@ class HybridParallelPlugin(PipelinePluginBase):
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if not isinstance(model, ModelWrapper):
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group)
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
self.ddp_config)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
if self.precision in ['fp16', 'bf16']:
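The use_ddp flag above is what decides between DDP and the other data-parallel paths: DDP is picked only when the derived data-parallel degree is greater than one while pipeline parallelism and ZeRO are both disabled; otherwise gradient synchronization is left to the ZeRO optimizer or to the manual all-reduce over dp_group shown earlier in this file. A small sketch of that decision, using the same quantities the plugin derives:

```python
def should_use_ddp(world_size: int, tp_size: int, pp_size: int, zero_stage: int) -> bool:
    # The data-parallel degree is whatever is left of the world size after
    # tensor and pipeline parallelism have taken their share.
    assert world_size % (tp_size * pp_size) == 0
    dp_size = world_size // (tp_size * pp_size)
    # Mirror of `use_ddp` above: real data parallelism, no pipeline stages, ZeRO disabled.
    return dp_size > 1 and pp_size == 1 and zero_stage == 0


assert should_use_ddp(world_size=4, tp_size=2, pp_size=1, zero_stage=0)        # tp=2, dp=2 -> DDP
assert not should_use_ddp(world_size=4, tp_size=2, pp_size=2, zero_stage=0)    # pipeline path handles sync
assert not should_use_ddp(world_size=4, tp_size=2, pp_size=1, zero_stage=1)    # ZeRO handles sync
```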

View File

@@ -13,6 +13,7 @@ from torch.optim import Adam, Optimizer
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@@ -259,3 +260,15 @@ def check_grad(org_model: Module,
assert torch.allclose(
org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
def unwrap_model(module: Module,
base_model_class_name: Optional[str] = None,
base_model_attribute_name: Optional[str] = None):
if isinstance(module, HybridParallelModule):
module = module.unwrap()
if base_model_class_name is None:
return module
if module.__class__.__name__ == base_model_class_name:
return module
return getattr(module, base_model_attribute_name, None)
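The new unwrap_model helper covers the three cases the per-model tests previously handled by hand: a HybridParallelModule (whose unwrap() now also strips the DDP wrapper), a bare base model whose class name matches base_model_class_name, and a task head whose backbone sits under base_model_attribute_name. A hedged usage sketch, assuming unwrap_model from the hunk above is in scope and using Hugging Face BERT classes purely for illustration:

```python
from transformers import BertConfig, BertForSequenceClassification, BertModel

config = BertConfig(num_hidden_layers=2)
head = BertForSequenceClassification(config)

# Task head: the backbone is reached through the 'bert' attribute.
assert unwrap_model(head, 'BertModel', 'bert') is head.bert
# Bare backbone: the class name matches, so it is returned unchanged.
backbone = BertModel(config)
assert unwrap_model(backbone, 'BertModel', 'bert') is backbone
# No class name given: whatever unwrapping yields is returned as-is.
assert unwrap_model(head) is head
```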

View File

@@ -15,6 +15,7 @@ from tests.test_shardformer.test_model._utils import (
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@@ -44,13 +45,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'BertModel':
bert = org_model
sharded_bert = sharded_model.unwrap()
else:
bert = org_model.bert
sharded_bert = sharded_model.unwrap().bert
bert = unwrap_model(org_model, 'BertModel', 'bert')
sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert')
col_layer_for_check = ['encoder.layer[0].output.dense']
row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']
@@ -98,6 +95,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}])
def run_bert_test(test_config):

View File

@@ -13,6 +13,7 @@ from tests.test_shardformer.test_model._utils import (
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@@ -46,12 +47,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'BloomModel':
bloom = org_model
sharded_bloom = sharded_model.unwrap()
else:
bloom = org_model.transformer
sharded_bloom = sharded_model.unwrap().transformer
bloom = unwrap_model(org_model, 'BloomModel', 'transformer')
sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer')
# check grad
row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
@@ -97,12 +94,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}])
def run_bloom_test(test_config):
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():

View File

@@ -14,6 +14,7 @@ from tests.test_shardformer.test_model._utils import (
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@@ -48,12 +49,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'ChatGLMModel':
chatglm_model = org_model
shard_chatglm_model = sharded_model.unwrap()
else:
chatglm_model = org_model.transformer
shard_chatglm_model = sharded_model.unwrap().transformer
chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer')
shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer')
# check grad
row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
@@ -121,12 +118,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}])
def run_chatglm_test(test_config):
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():

View File

@@ -3,7 +3,6 @@ import torch
from torch import distributed as dist
import colossalai
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@@ -15,6 +14,7 @@ from tests.test_shardformer.test_model._utils import (
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@@ -48,16 +48,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
def unwrap(module):
if isinstance(module, HybridParallelModule):
module = module.unwrap()
if module.__class__.__name__ == 'GPT2Model':
return module
return module.transformer
# unwrap model
gpt2 = unwrap(org_model)
sharded_gpt2 = unwrap(sharded_model)
gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer')
sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer')
col_layer_for_check = ['h[0].mlp.c_fc']
row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
@@ -106,6 +99,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
}, {
'tp_size': 4,
'pp_size': 1,
@@ -117,8 +116,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@clear_cache_before_run()
def run_gpt2_test(test_config):
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():

View File

@@ -16,6 +16,7 @@ from tests.test_shardformer.test_model._utils import (
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@@ -52,12 +53,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'LlamaModel':
llama_model = org_model
shard_llama_model = sharded_model.unwrap()
else:
llama_model = org_model.model
shard_llama_model = sharded_model.unwrap().model
llama_model = unwrap_model(org_model, 'LlamaModel', 'model')
shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model')
# check grad
row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
@@ -128,13 +125,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}])
def run_llama_test(test_config):
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():

View File

@@ -16,6 +16,7 @@ from tests.test_shardformer.test_model._utils import (
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@@ -51,12 +52,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'OPTModel':
opt_model = org_model
shard_opt_model = sharded_model.unwrap()
else:
opt_model = org_model.model
shard_opt_model = sharded_model.unwrap().model
opt_model = unwrap_model(org_model, 'OPTModel', 'model')
shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model')
# check grad
row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens'
@@ -123,14 +120,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}])
def run_opt_test(test_config):
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

View File

@@ -1,5 +1,6 @@
import pytest
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.logging import disable_existing_loggers
@@ -14,6 +15,7 @@ from tests.test_shardformer.test_model._utils import (
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@@ -48,8 +50,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
t5 = org_model
sharded_t5 = sharded_model.unwrap()
t5 = unwrap_model(org_model)
sharded_t5 = unwrap_model(sharded_model)
row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
@@ -99,17 +101,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}])
@clear_cache_before_run()
def run_t5_test(test_config):
# TODO(baizhou): add plugin_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
# TODO(baizhou): add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():

View File

@@ -14,6 +14,7 @@ from tests.test_shardformer.test_model._utils import (
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@@ -48,12 +49,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'ViTModel':
vit_model = org_model
shard_vit_model = sharded_model.unwrap()
else:
vit_model = org_model.vit
shard_vit_model = sharded_model.unwrap().vit
vit_model = unwrap_model(org_model, 'ViTModel', 'vit')
shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit')
# check grad
row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
@@ -120,15 +117,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}])
def run_vit_test(test_config):
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
# TODO(baizhou): fix bug when setting lazy_init for Conv2D Layers in ViT models
# TODO: fix bug when setting lazy_init for Conv2D Layers in ViT models
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)