[shardformer] support tp+zero for shardformer (#4472)

* support tp+zero/input type cast for hybridplugin

* add tp+zero tests

* fix bucket arguments
Baizhou Zhang 2023-08-21 12:04:52 +08:00 committed by GitHub
parent 8739aa7fa0
commit 1c7df566e2
9 changed files with 136 additions and 37 deletions
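
This commit lets HybridParallelPlugin combine tensor parallelism with ZeRO-1/2 data parallelism and adds the input type cast needed for mixed precision. The snippet below is a minimal usage sketch, not part of the commit: it assumes the plugin is imported from colossalai.booster.plugin, that the processes are launched with torchrun, and that the BERT model, optimizer, and learning rate are placeholders chosen for illustration (the new tests in this commit use the same tp_size=2, zero_stage=2, fp16 configuration).

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import BertConfig, BertForSequenceClassification

# Launch the distributed environment (torchrun-style environment variables assumed).
# ZeRO requires dp_size > 1, so world_size must exceed tp_size * pp_size, e.g. 4 GPUs here.
colossalai.launch_from_torch(config={})

# Tensor parallelism combined with ZeRO stage 2, mirroring the new test configs.
plugin = HybridParallelPlugin(tp_size=2,
                              pp_size=1,
                              zero_stage=2,
                              precision='fp16',
                              initial_scale=1)
booster = Booster(plugin=plugin)

model = BertForSequenceClassification(BertConfig(num_hidden_layers=2))    # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)                 # placeholder optimizer

# boost() shards the model with Shardformer and, because zero_stage > 0,
# wraps the optimizer in HybridParallelZeroOptimizer (see the diff below).
model, optimizer, *_ = booster.boost(model, optimizer)

Training then typically goes through booster.backward(loss, optimizer) instead of loss.backward(), so the wrapped ZeRO optimizer can drive gradient bucketing and reduction.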

View File

@@ -1,5 +1,6 @@
 import random
 from contextlib import nullcontext
+from functools import partial
 from typing import Any, Callable, Iterator, List, Optional, Tuple, Union

 import numpy as np

@@ -10,6 +11,7 @@ from torch.nn import Module, SyncBatchNorm
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
@@ -27,32 +29,49 @@ from .pp_plugin_base import PipelinePluginBase

 DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2


+def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
+    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
+        return x.to(dtype)
+    return x
+
+
 class HybridParallelModule(ModelWrapper):

     def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
                  ddp_config: dict) -> None:

         self.stage_manager = shard_config.pipeline_stage_manager
         self.dp_group = dp_group

         shardformer = ShardFormer(shard_config)
         module, self.shared_params = shardformer.optimize(module)

-        # TODO(ver217): add input type cast
+        # setting process groups for shared parameters
         self.shared_param_process_groups = []
         for shared_param in self.shared_params:
             if len(shared_param) > 0:
                 self.shared_param_process_groups.append(
                     self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))

+        # setting mixed_precision
+        self.mixed_precision = None
         if precision == 'fp16':
-            module = module.half().cuda()
+            self.mixed_precision = torch.float16
         elif precision == 'bf16':
-            module = module.to(dtype=torch.bfloat16).cuda()
+            self.mixed_precision = torch.bfloat16
-        else:
-            module = module.cuda()    # train without AMP
+        if self.mixed_precision is not None:
+            module = module.to(self.mixed_precision)
+        module = module.cuda()
+
+        # setting input type cast when using mixed precision
+        self.convert_fn = None
+        if self.mixed_precision is not None:
+            self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision)
+
+        # setting ddp configs
         if use_ddp:
             # convert model to sync bn
             module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group)
             # wrap the model with PyTorch DDP
             module = DDP(module, process_group=dp_group, **ddp_config)
@@ -78,6 +97,12 @@ class HybridParallelModule(ModelWrapper):
                 dist.all_reduce(p.grad, group=self.dp_group)
                 p.grad.div_(self.dp_group.size())

+    def forward(self, *args, **kwargs):
+        if self.convert_fn is not None:
+            args = tree_map(self.convert_fn, args)
+            kwargs = tree_map(self.convert_fn, kwargs)
+        return super().forward(*args, **kwargs)
+
     def unwrap(self):
         module = super().unwrap()
         if isinstance(module, DDP):
@@ -180,7 +205,6 @@ class HybridParallelPlugin(PipelinePluginBase):
            Defaults to 'fp16'.
        zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
            When set to 0, ZeRO will not be used. Defaults to 0.
-       cpu_offload (bool, optional): Whether to enable cpu_offload when using ZeRO. Defaults to False.
        enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
            Currently all the optimization methods include fused normalization, flash attention and JIT.
            Defaults to False.

@@ -196,12 +220,16 @@ class HybridParallelPlugin(PipelinePluginBase):
        hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
        max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
        max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
-       broadcast_buffers (bool, optional): Whether to broadcast buffers at the beginning of training. Only for usage of DDP. Defaults to True.
-       bucket_cap_mb (int, optional): The bucket size in MB. Only for usage of DDP. Defaults to 25.
-       find_unused_parameters (bool, optional): Whether to find unused parameters. Only for usage of DDP. Defaults to False.
-       check_reduction (bool, optional): Whether to check reduction. Only for usage of DDP. Defaults to False.
-       gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Only for usage of DDP. Defaults to False.
-       static_graph (bool, optional): Whether to use static graph. Only for usage of DDP. Defaults to False.
+       broadcast_buffers (bool, optional): Whether to broadcast buffers at the beginning of training when using DDP. Defaults to True.
+       ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
+       find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
+       check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
+       gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
+       static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
+       zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
+       cpu_offload (bool, optional): Whether to enable cpu_offload when using ZeRO. Defaults to False.
+       communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
+       overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
    """

    def __init__(self,
@@ -209,7 +237,6 @@ class HybridParallelPlugin(PipelinePluginBase):
                 pp_size: int,
                 precision: str = 'fp16',
                 zero_stage: int = 0,
-                cpu_offload: bool = False,
                 enable_all_optimization: bool = False,
                 enable_fused_normalization: bool = False,
                 enable_flash_attention: bool = False,

@@ -224,12 +251,16 @@ class HybridParallelPlugin(PipelinePluginBase):
                 hysteresis: int = 2,
                 max_scale: float = 2**32,
                 max_norm: float = 0,
-                broadcast_buffers=True,
-                bucket_cap_mb=25,
-                find_unused_parameters=False,
-                check_reduction=False,
-                gradient_as_bucket_view=False,
-                static_graph=False) -> None:
+                broadcast_buffers: bool = True,
+                ddp_bucket_cap_mb: int = 25,
+                find_unused_parameters: bool = False,
+                check_reduction: bool = False,
+                gradient_as_bucket_view: bool = False,
+                static_graph: bool = False,
+                zero_bucket_size_in_m: int = 12,
+                cpu_offload: bool = False,
+                communication_dtype: Optional[torch.dtype] = None,
+                overlap_communication: bool = True) -> None:
        super().__init__()

        assert dist.get_world_size() % (
@@ -239,8 +270,6 @@ class HybridParallelPlugin(PipelinePluginBase):
        if enable_sequence_parallelism:
            assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism'
-       # TODO(ver217): support zero
-       assert zero_stage == 0, 'zero is not support yet'

        self.tp_size = tp_size
        self.pp_size = pp_size
        self.dp_size = dist.get_world_size() // (tp_size * pp_size)
@@ -282,11 +311,18 @@ class HybridParallelPlugin(PipelinePluginBase):
        )

        self.ddp_config = dict(broadcast_buffers=broadcast_buffers,
-                              bucket_cap_mb=bucket_cap_mb,
+                              bucket_cap_mb=ddp_bucket_cap_mb,
                               find_unused_parameters=find_unused_parameters,
                               check_reduction=check_reduction,
                               gradient_as_bucket_view=gradient_as_bucket_view,
                               static_graph=static_graph)

+       self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
+                               communication_dtype=communication_dtype,
+                               overlap_communication=overlap_communication,
+                               cpu_offload=cpu_offload,
+                               partition_grad=(self.zero_stage == 2))
+
        self.max_norm = max_norm

    @property
@@ -337,15 +373,16 @@ class HybridParallelPlugin(PipelinePluginBase):
                    model,
                    use_pipeline=self.enable_pipeline_parallelism)
            else:
+               assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
+               assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
                optimizer = HybridParallelZeroOptimizer(optimizer,
                                                        model,
                                                        use_pipeline=self.enable_pipeline_parallelism,
-                                                       partition_grad=(self.zero_stage == 2),
-                                                       cpu_offload=self.cpu_offload,
                                                        dp_process_group=self.dp_group,
                                                        tp_process_group=self.tp_group,
                                                        verbose=True,
                                                        clip_grad_norm=self.max_norm,
+                                                       **self.zero_config,
                                                        **self.amp_config)

        return model, optimizer, criterion, dataloader, lr_scheduler
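
A note on the forward() override and _convert_floating_point added above: inputs are cast with torch.utils._pytree.tree_map, so floating-point tensors anywhere in args/kwargs are moved to the autocast dtype, while integer tensors (token ids, attention masks) and non-tensor values pass through unchanged. Below is a standalone sketch of that behavior, using only PyTorch and a copy of the helper from the diff; the example kwargs are hypothetical.

from functools import partial

import torch
from torch.utils._pytree import tree_map


def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
    # Same helper as in the diff: cast only floating-point tensors.
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype)
    return x


convert_fn = partial(_convert_floating_point, dtype=torch.float16)

# Hypothetical forward kwargs for a transformer model.
kwargs = {
    'input_ids': torch.randint(0, 30522, (2, 16)),            # int64, left unchanged
    'attention_mask': torch.ones(2, 16, dtype=torch.long),    # int64, left unchanged
    'inputs_embeds': torch.randn(2, 16, 768),                 # fp32 -> fp16
}
kwargs = tree_map(convert_fn, kwargs)

print({k: v.dtype for k, v in kwargs.items()})
# {'input_ids': torch.int64, 'attention_mask': torch.int64, 'inputs_embeds': torch.float16}

Using tree_map keeps the cast agnostic to how the call site nests its inputs (tuples, lists, or dicts of tensors).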

View File

@@ -56,9 +56,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        atol, rtol = 1e-4, 1e-3
    else:
        atol, rtol = 5e-3, 5e-3
-   if stage_manager is None or stage_manager.is_first_stage():
-       #check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3)
-       #check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3)
+   if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
        check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
        check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)

@@ -101,6 +99,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'enable_all_optimization': True,
    'use_lazy_init': False,
    'precision': 'fp32'
+}, {
+   'tp_size': 2,
+   'pp_size': 1,
+   'enable_all_optimization': True,
+   'use_lazy_init': True,
+   'zero_stage': 2,
+   'precision': 'fp16',
+   'initial_scale': 1
 }])
 def run_bert_test(test_config):

View File

@@ -53,7 +53,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    # check grad
    row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
    col_layer_for_check = ['h[0].self_attention.dense']
-   if stage_manager is None or stage_manager.is_first_stage():
+   if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
        if test_config['precision'] == 'fp32':
            atol, rtol = 1e-6, 1e-5
        else:

@@ -101,6 +101,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'enable_all_optimization': True,
    'use_lazy_init': False,
    'precision': 'fp32'
+}, {
+   'tp_size': 2,
+   'pp_size': 1,
+   'enable_all_optimization': True,
+   'use_lazy_init': True,
+   'zero_stage': 2,
+   'precision': 'fp16',
+   'initial_scale': 1
 }])
 def run_bloom_test(test_config):

View File

@@ -55,7 +55,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    # check grad
    row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
    col_layer_for_check = ['encoder.layers[0].self_attention.dense']
-   if stage_manager is None or stage_manager.is_first_stage():
+   if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
        if test_config['precision'] == 'fp32':
            atol, rtol = 1e-6, 1e-3
        else:

@@ -125,6 +125,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'enable_all_optimization': True,
    'use_lazy_init': False,
    'precision': 'fp32'
+}, {
+   'tp_size': 2,
+   'pp_size': 1,
+   'enable_all_optimization': True,
+   'use_lazy_init': True,
+   'zero_stage': 2,
+   'precision': 'fp16',
+   'initial_scale': 1
 }])
 def run_chatglm_test(test_config):

View File

@@ -56,7 +56,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    row_layer_for_check = ['wte', 'h[0].mlp.c_proj']

    # check grad
-   if stage_manager is None or stage_manager.is_first_stage():
+   if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
        if test_config['precision'] == 'fp32':
            atol, rtol = 1e-4, 1e-3
        else:

@@ -120,6 +120,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'use_lazy_init': True,
    'enable_sequence_parallelism': True,
    'precision': 'fp32',
+}, {
+   'tp_size': 2,
+   'pp_size': 1,
+   'enable_all_optimization': True,
+   'use_lazy_init': True,
+   'zero_stage': 2,
+   'precision': 'fp16',
+   'initial_scale': 1
 }])
 @clear_cache_before_run()
 def run_gpt2_test(test_config):

View File

@@ -60,7 +60,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    # check grad
    row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
    col_layer_for_check = ['layers[0].self_attn.o_proj']
-   if stage_manager is None or stage_manager.is_first_stage():
+   if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
        if test_config['precision'] == 'fp32':
            atol, rtol = 1e-6, 1e-4
        else:

@@ -135,6 +135,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'enable_all_optimization': True,
    'use_lazy_init': False,
    'precision': 'fp32'
+}, {
+   'tp_size': 2,
+   'pp_size': 1,
+   'enable_all_optimization': True,
+   'use_lazy_init': True,
+   'zero_stage': 2,
+   'precision': 'fp16',
+   'initial_scale': 1
 }])
 def run_llama_test(test_config):

View File

@@ -58,7 +58,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    # check grad
    row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens']    # 'decoder.embed_tokens'
    col_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
-   if stage_manager is None or stage_manager.is_first_stage():
+   if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
        if test_config['precision'] == 'fp32':
            atol, rtol = 1e-6, 1e-3
        else:

@@ -127,6 +127,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'enable_all_optimization': True,
    'use_lazy_init': False,
    'precision': 'fp32'
+}, {
+   'tp_size': 2,
+   'pp_size': 1,
+   'enable_all_optimization': True,
+   'use_lazy_init': True,
+   'zero_stage': 2,
+   'precision': 'fp16',
+   'initial_scale': 1
 }])
 def run_opt_test(test_config):

View File

@@ -55,12 +55,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']

-   # check weights and gradients
+   # check grad
    if test_config['precision'] == 'fp32':
        atol, rtol = 1e-5, 1e-3
    else:
        atol, rtol = 5e-3, 5e-3
-   if stage_manager is None or stage_manager.is_first_stage():
+   if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
        check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)

    # check weights after optimizer.step()

@@ -110,6 +110,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'enable_all_optimization': True,
    'use_lazy_init': False,
    'precision': 'fp32'
+}, {
+   'tp_size': 2,
+   'pp_size': 1,
+   'enable_all_optimization': True,
+   'use_lazy_init': True,
+   'zero_stage': 2,
+   'precision': 'fp16',
+   'initial_scale': 1
 }])
 @clear_cache_before_run()
 def run_t5_test(test_config):

View File

@@ -55,7 +55,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    # check grad
    row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
    col_layer_for_check = ['encoder.layer[0].attention.output.dense']
-   if stage_manager is None or stage_manager.is_first_stage():
+   if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
        if test_config['precision'] == 'fp32':
            atol, rtol = 1e-5, 1e-3
        else:

@@ -124,6 +124,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    'enable_all_optimization': True,
    'use_lazy_init': False,
    'precision': 'fp32'
+}, {
+   'tp_size': 2,
+   'pp_size': 1,
+   'enable_all_optimization': True,
+   'use_lazy_init': False,
+   'zero_stage': 2,
+   'precision': 'fp16',
+   'initial_scale': 1
 }])
 def run_vit_test(test_config):