mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] support DDP in HybridPlugin/add tp+dp tests (#4446)
* support DDP for HybridPlugin/add tp+dp tests * add docstring for HybridParallelPluginpull/4460/head
parent
424629fea0
commit
6ef33f75aa
|
@ -6,7 +6,8 @@ import numpy as np
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn import Module
|
||||
from torch.nn import Module, SyncBatchNorm
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
|
@ -28,7 +29,8 @@ DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
|
|||
|
||||
class HybridParallelModule(ModelWrapper):
|
||||
|
||||
def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup) -> None:
|
||||
def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
|
||||
ddp_config: dict) -> None:
|
||||
self.stage_manager = shard_config.pipeline_stage_manager
|
||||
self.dp_group = dp_group
|
||||
shardformer = ShardFormer(shard_config)
|
||||
|
@ -45,7 +47,15 @@ class HybridParallelModule(ModelWrapper):
|
|||
module = module.to(dtype=torch.bfloat16).cuda()
|
||||
else:
|
||||
module = module.cuda() # train without AMP
|
||||
# TODO(ver217): support TP+DP
|
||||
|
||||
if use_ddp:
|
||||
|
||||
# convert model to sync bn
|
||||
module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group)
|
||||
|
||||
# wrap the model with PyTorch DDP
|
||||
module = DDP(module, process_group=dp_group, **ddp_config)
|
||||
|
||||
super().__init__(module)
|
||||
|
||||
def sync_shared_params(self):
|
||||
|
@ -68,6 +78,12 @@ class HybridParallelModule(ModelWrapper):
|
|||
dist.all_reduce(p.grad, group=self.dp_group)
|
||||
p.grad.div_(self.dp_group.size())
|
||||
|
||||
def unwrap(self):
|
||||
module = super().unwrap()
|
||||
if isinstance(module, DDP):
|
||||
module = module.module
|
||||
return module
|
||||
|
||||
|
||||
def init_pipeline_optimizer(optim: Optimizer, model: Module):
|
||||
params = set(model.parameters())
|
||||
|
@ -140,29 +156,81 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
|
||||
|
||||
class HybridParallelPlugin(PipelinePluginBase):
|
||||
"""
|
||||
Plugin for Hybrid Parallel Training.
|
||||
Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin.
|
||||
The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size).
|
||||
|
||||
Example:
|
||||
>>> from colossalai.booster import Booster
|
||||
>>> from colossalai.booster.plugin import HybridParallelPlugin
|
||||
|
||||
>>> model, train_dataset, optimizer, criterion = ...
|
||||
>>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
|
||||
|
||||
>>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
|
||||
>>> booster = Booster(plugin=plugin)
|
||||
>>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
|
||||
|
||||
Args:
|
||||
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
|
||||
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
|
||||
precision (str, optional): Specifies the precision of parameters during training.
|
||||
Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
|
||||
Defaults to 'fp16'.
|
||||
zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2].
|
||||
When set to 0, ZeRO will not be used. Defaults to 0.
|
||||
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
|
||||
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
|
||||
Currently all the optimization methods include fused normalization, flash attention and JIT.
|
||||
Defaults to False.
|
||||
enable_fused_normalization (bool, optional): Whether to switch on fused normalization. Defaults to False.
|
||||
enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
|
||||
enable_jit_fused (bool, optional): Whether to switch on JIT. Default to Falase.
|
||||
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
|
||||
initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
|
||||
min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
|
||||
growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
|
||||
backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
|
||||
growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
|
||||
hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
|
||||
max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
|
||||
max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
|
||||
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Only for usage of DDP. Defaults to True.
|
||||
bucket_cap_mb (int, optional): The bucket size in MB. Only for usage of DDP. Defaults to 25.
|
||||
find_unused_parameters (bool, optional): Whether to find unused parameters. Only for usage of DDP. Defaults to False.
|
||||
check_reduction (bool, optional): Whether to check reduction. Only for usage of DDP. Defaults to False.
|
||||
gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Only for usage of DDP. Defaults to False.
|
||||
static_graph (bool, optional): Whether to use static graph. Only for usage of DDP. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
precision: str = 'fp16',
|
||||
zero_stage: int = 0,
|
||||
cpu_offload: bool = False,
|
||||
enable_all_optimization: bool = False,
|
||||
enable_fused_normalization: bool = False,
|
||||
enable_flash_attention: bool = False,
|
||||
enable_jit_fused: bool = False,
|
||||
enable_sequence_parallelism: bool = False,
|
||||
num_microbatches: Optional[int] = None,
|
||||
initial_scale: float = 2**16,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 1000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: float = 2**32,
|
||||
max_norm: float = 0,
|
||||
broadcast_buffers=True,
|
||||
bucket_cap_mb=25,
|
||||
find_unused_parameters=False,
|
||||
check_reduction=False,
|
||||
gradient_as_bucket_view=False,
|
||||
static_graph=False) -> None:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
precision: str = 'fp16',
|
||||
zero_stage: int = 0,
|
||||
cpu_offload: bool = False,
|
||||
enable_all_optimization: bool = False,
|
||||
enable_fused_normalization: bool = False,
|
||||
enable_flash_attention: bool = False,
|
||||
enable_jit_fused: bool = False,
|
||||
enable_sequence_parallelism: bool = False,
|
||||
num_microbatches: Optional[int] = None,
|
||||
initial_scale: float = 2**16,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 1000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: float = 2**32,
|
||||
max_norm: float = 0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dist.get_world_size() % (
|
||||
tp_size * pp_size
|
||||
|
@ -208,6 +276,13 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
min_scale=min_scale,
|
||||
max_scale=max_scale,
|
||||
)
|
||||
|
||||
self.ddp_config = dict(broadcast_buffers=broadcast_buffers,
|
||||
bucket_cap_mb=bucket_cap_mb,
|
||||
find_unused_parameters=find_unused_parameters,
|
||||
check_reduction=check_reduction,
|
||||
gradient_as_bucket_view=gradient_as_bucket_view,
|
||||
static_graph=static_graph)
|
||||
self.max_norm = max_norm
|
||||
|
||||
@property
|
||||
|
@ -241,7 +316,9 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
lr_scheduler: Optional[LRScheduler] = None,
|
||||
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
|
||||
if not isinstance(model, ModelWrapper):
|
||||
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group)
|
||||
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
|
||||
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
|
||||
self.ddp_config)
|
||||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
if self.zero_stage == 0:
|
||||
if self.precision in ['fp16', 'bf16']:
|
||||
|
|
|
@ -13,6 +13,7 @@ from torch.optim import Adam, Optimizer
|
|||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import HybridParallelPlugin
|
||||
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
|
@ -259,3 +260,15 @@ def check_grad(org_model: Module,
|
|||
assert torch.allclose(
|
||||
org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol
|
||||
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
|
||||
|
||||
|
||||
def unwrap_model(module: Module,
|
||||
base_model_class_name: Optional[str] = None,
|
||||
base_model_attribute_name: Optional[str] = None):
|
||||
if isinstance(module, HybridParallelModule):
|
||||
module = module.unwrap()
|
||||
if base_model_class_name is None:
|
||||
return module
|
||||
if module.__class__.__name__ == base_model_class_name:
|
||||
return module
|
||||
return getattr(module, base_model_attribute_name, None)
|
||||
|
|
|
@ -15,6 +15,7 @@ from tests.test_shardformer.test_model._utils import (
|
|||
check_output_hidden_state,
|
||||
check_weight,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
||||
|
||||
|
@ -44,13 +45,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
|
||||
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
# unwrap model
|
||||
if org_model.__class__.__name__ == 'BertModel':
|
||||
bert = org_model
|
||||
sharded_bert = sharded_model.unwrap()
|
||||
else:
|
||||
bert = org_model.bert
|
||||
sharded_bert = sharded_model.unwrap().bert
|
||||
|
||||
bert = unwrap_model(org_model, 'BertModel', 'bert')
|
||||
sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert')
|
||||
|
||||
col_layer_for_check = ['encoder.layer[0].output.dense']
|
||||
row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']
|
||||
|
@ -98,6 +95,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32'
|
||||
}])
|
||||
def run_bert_test(test_config):
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ from tests.test_shardformer.test_model._utils import (
|
|||
check_output_hidden_state,
|
||||
check_weight,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
||||
|
||||
|
@ -46,12 +47,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# unwrap model
|
||||
if org_model.__class__.__name__ == 'BloomModel':
|
||||
bloom = org_model
|
||||
sharded_bloom = sharded_model.unwrap()
|
||||
else:
|
||||
bloom = org_model.transformer
|
||||
sharded_bloom = sharded_model.unwrap().transformer
|
||||
bloom = unwrap_model(org_model, 'BloomModel', 'transformer')
|
||||
sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer')
|
||||
|
||||
# check grad
|
||||
row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
|
||||
|
@ -97,12 +94,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
'precision': 'fp32'
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32'
|
||||
}])
|
||||
def run_bloom_test(test_config):
|
||||
|
||||
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
|
|
|
@ -14,6 +14,7 @@ from tests.test_shardformer.test_model._utils import (
|
|||
check_output_hidden_state,
|
||||
check_weight,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
||||
|
||||
|
@ -48,12 +49,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# unwrap model
|
||||
if org_model.__class__.__name__ == 'ChatGLMModel':
|
||||
chatglm_model = org_model
|
||||
shard_chatglm_model = sharded_model.unwrap()
|
||||
else:
|
||||
chatglm_model = org_model.transformer
|
||||
shard_chatglm_model = sharded_model.unwrap().transformer
|
||||
chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer')
|
||||
shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer')
|
||||
|
||||
# check grad
|
||||
row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
|
||||
|
@ -121,12 +118,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
'precision': 'fp32'
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32'
|
||||
}])
|
||||
def run_chatglm_test(test_config):
|
||||
|
||||
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
|
|
|
@ -3,7 +3,6 @@ import torch
|
|||
from torch import distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
@ -15,6 +14,7 @@ from tests.test_shardformer.test_model._utils import (
|
|||
check_output_hidden_state,
|
||||
check_weight,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
||||
|
||||
|
@ -48,16 +48,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
def unwrap(module):
|
||||
if isinstance(module, HybridParallelModule):
|
||||
module = module.unwrap()
|
||||
if module.__class__.__name__ == 'GPT2Model':
|
||||
return module
|
||||
return module.transformer
|
||||
|
||||
# unwrap model
|
||||
gpt2 = unwrap(org_model)
|
||||
sharded_gpt2 = unwrap(sharded_model)
|
||||
gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer')
|
||||
sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer')
|
||||
|
||||
col_layer_for_check = ['h[0].mlp.c_fc']
|
||||
row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
|
||||
|
@ -106,6 +99,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
}, {
|
||||
'tp_size': 4,
|
||||
'pp_size': 1,
|
||||
|
@ -117,8 +116,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
@clear_cache_before_run()
|
||||
def run_gpt2_test(test_config):
|
||||
|
||||
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
|
|
|
@ -16,6 +16,7 @@ from tests.test_shardformer.test_model._utils import (
|
|||
check_output_hidden_state,
|
||||
check_weight,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
||||
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
||||
|
@ -52,12 +53,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# unwrap model
|
||||
if org_model.__class__.__name__ == 'LlamaModel':
|
||||
llama_model = org_model
|
||||
shard_llama_model = sharded_model.unwrap()
|
||||
else:
|
||||
llama_model = org_model.model
|
||||
shard_llama_model = sharded_model.unwrap().model
|
||||
llama_model = unwrap_model(org_model, 'LlamaModel', 'model')
|
||||
shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model')
|
||||
|
||||
# check grad
|
||||
row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
|
||||
|
@ -128,13 +125,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
'tp_size': 1,
|
||||
'pp_size': 4,
|
||||
'num_microbatches': 4,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
'precision': 'fp32'
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32'
|
||||
}])
|
||||
def run_llama_test(test_config):
|
||||
|
||||
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
|
|
|
@ -16,6 +16,7 @@ from tests.test_shardformer.test_model._utils import (
|
|||
check_output_hidden_state,
|
||||
check_weight,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
||||
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
||||
|
@ -51,12 +52,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# unwrap model
|
||||
if org_model.__class__.__name__ == 'OPTModel':
|
||||
opt_model = org_model
|
||||
shard_opt_model = sharded_model.unwrap()
|
||||
else:
|
||||
opt_model = org_model.model
|
||||
shard_opt_model = sharded_model.unwrap().model
|
||||
opt_model = unwrap_model(org_model, 'OPTModel', 'model')
|
||||
shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model')
|
||||
|
||||
# check grad
|
||||
row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens'
|
||||
|
@ -123,14 +120,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
'precision': 'fp32'
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32'
|
||||
}])
|
||||
def run_opt_test(test_config):
|
||||
|
||||
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
|
@ -14,6 +15,7 @@ from tests.test_shardformer.test_model._utils import (
|
|||
check_output_hidden_state,
|
||||
check_weight,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
||||
|
||||
|
@ -48,8 +50,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# unwrap model
|
||||
t5 = org_model
|
||||
sharded_t5 = sharded_model.unwrap()
|
||||
t5 = unwrap_model(org_model)
|
||||
sharded_t5 = unwrap_model(sharded_model)
|
||||
|
||||
row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
|
||||
|
||||
|
@ -99,17 +101,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
'tp_size': 1,
|
||||
'pp_size': 4,
|
||||
'num_microbatches': 4,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
'precision': 'fp32'
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32'
|
||||
}])
|
||||
@clear_cache_before_run()
|
||||
def run_t5_test(test_config):
|
||||
|
||||
# TODO(baizhou): add plugin_config for TP+DP after supporting & debugging it
|
||||
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
|
||||
|
||||
# TODO(baizhou): add test_config for flash attention & jit operator after supporting
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
|
|
|
@ -14,6 +14,7 @@ from tests.test_shardformer.test_model._utils import (
|
|||
check_output_hidden_state,
|
||||
check_weight,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
||||
|
||||
|
@ -48,12 +49,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# unwrap model
|
||||
if org_model.__class__.__name__ == 'ViTModel':
|
||||
vit_model = org_model
|
||||
shard_vit_model = sharded_model.unwrap()
|
||||
else:
|
||||
vit_model = org_model.vit
|
||||
shard_vit_model = sharded_model.unwrap().vit
|
||||
vit_model = unwrap_model(org_model, 'ViTModel', 'vit')
|
||||
shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit')
|
||||
|
||||
# check grad
|
||||
row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
|
||||
|
@ -120,15 +117,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
'precision': 'fp32'
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32'
|
||||
}])
|
||||
def run_vit_test(test_config):
|
||||
|
||||
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
|
||||
# TODO(baizhou): fix bug when settign lazy_init for Conv2D Layers in ViT models
|
||||
# TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
|
|
Loading…
Reference in New Issue