mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] support tp+zero for shardformer (#4472)
* support tp+zero/input type cast for hybridplugin * add tp+zero tests * fix bucket argumentspull/4484/head
parent
8739aa7fa0
commit
1c7df566e2
|
@ -1,5 +1,6 @@
|
|||
import random
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
@ -10,6 +11,7 @@ from torch.nn import Module, SyncBatchNorm
|
|||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.utils._pytree import tree_map
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
|
@ -27,32 +29,49 @@ from .pp_plugin_base import PipelinePluginBase
|
|||
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
|
||||
|
||||
|
||||
def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
|
||||
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
|
||||
return x.to(dtype)
|
||||
return x
|
||||
|
||||
|
||||
class HybridParallelModule(ModelWrapper):
|
||||
|
||||
def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
|
||||
ddp_config: dict) -> None:
|
||||
|
||||
self.stage_manager = shard_config.pipeline_stage_manager
|
||||
self.dp_group = dp_group
|
||||
|
||||
shardformer = ShardFormer(shard_config)
|
||||
module, self.shared_params = shardformer.optimize(module)
|
||||
# TODO(ver217): add input type cast
|
||||
|
||||
# setting process groups for shared parameters
|
||||
self.shared_param_process_groups = []
|
||||
for shared_param in self.shared_params:
|
||||
if len(shared_param) > 0:
|
||||
self.shared_param_process_groups.append(
|
||||
self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
|
||||
|
||||
# setting mixed_precision
|
||||
self.mixed_precision = None
|
||||
if precision == 'fp16':
|
||||
module = module.half().cuda()
|
||||
self.mixed_precision = torch.float16
|
||||
elif precision == 'bf16':
|
||||
module = module.to(dtype=torch.bfloat16).cuda()
|
||||
else:
|
||||
module = module.cuda() # train without AMP
|
||||
self.mixed_precision = torch.bfloat16
|
||||
if self.mixed_precision is not None:
|
||||
module = module.to(self.mixed_precision)
|
||||
module = module.cuda()
|
||||
|
||||
# setting input type cast when using mixed precision
|
||||
self.convert_fn = None
|
||||
if self.mixed_precision is not None:
|
||||
self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision)
|
||||
|
||||
# setting ddp configs
|
||||
if use_ddp:
|
||||
|
||||
# convert model to sync bn
|
||||
module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group)
|
||||
|
||||
# wrap the model with PyTorch DDP
|
||||
module = DDP(module, process_group=dp_group, **ddp_config)
|
||||
|
||||
|
@ -78,6 +97,12 @@ class HybridParallelModule(ModelWrapper):
|
|||
dist.all_reduce(p.grad, group=self.dp_group)
|
||||
p.grad.div_(self.dp_group.size())
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.convert_fn is not None:
|
||||
args = tree_map(self.convert_fn, args)
|
||||
kwargs = tree_map(self.convert_fn, kwargs)
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
def unwrap(self):
|
||||
module = super().unwrap()
|
||||
if isinstance(module, DDP):
|
||||
|
@ -180,7 +205,6 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
Defaults to 'fp16'.
|
||||
zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2].
|
||||
When set to 0, ZeRO will not be used. Defaults to 0.
|
||||
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
|
||||
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
|
||||
Currently all the optimization methods include fused normalization, flash attention and JIT.
|
||||
Defaults to False.
|
||||
|
@ -196,12 +220,16 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
|
||||
max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
|
||||
max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
|
||||
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Only for usage of DDP. Defaults to True.
|
||||
bucket_cap_mb (int, optional): The bucket size in MB. Only for usage of DDP. Defaults to 25.
|
||||
find_unused_parameters (bool, optional): Whether to find unused parameters. Only for usage of DDP. Defaults to False.
|
||||
check_reduction (bool, optional): Whether to check reduction. Only for usage of DDP. Defaults to False.
|
||||
gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Only for usage of DDP. Defaults to False.
|
||||
static_graph (bool, optional): Whether to use static graph. Only for usage of DDP. Defaults to False.
|
||||
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
|
||||
ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
|
||||
find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
|
||||
check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
|
||||
gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
|
||||
static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
|
||||
zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
|
||||
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
|
||||
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
|
||||
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -209,7 +237,6 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
pp_size: int,
|
||||
precision: str = 'fp16',
|
||||
zero_stage: int = 0,
|
||||
cpu_offload: bool = False,
|
||||
enable_all_optimization: bool = False,
|
||||
enable_fused_normalization: bool = False,
|
||||
enable_flash_attention: bool = False,
|
||||
|
@ -224,12 +251,16 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
hysteresis: int = 2,
|
||||
max_scale: float = 2**32,
|
||||
max_norm: float = 0,
|
||||
broadcast_buffers=True,
|
||||
bucket_cap_mb=25,
|
||||
find_unused_parameters=False,
|
||||
check_reduction=False,
|
||||
gradient_as_bucket_view=False,
|
||||
static_graph=False) -> None:
|
||||
broadcast_buffers: bool = True,
|
||||
ddp_bucket_cap_mb: int = 25,
|
||||
find_unused_parameters: bool = False,
|
||||
check_reduction: bool = False,
|
||||
gradient_as_bucket_view: bool = False,
|
||||
static_graph: bool = False,
|
||||
zero_bucket_size_in_m: int = 12,
|
||||
cpu_offload: bool = False,
|
||||
communication_dtype: Optional[torch.dtype] = None,
|
||||
overlap_communication: bool = True) -> None:
|
||||
|
||||
super().__init__()
|
||||
assert dist.get_world_size() % (
|
||||
|
@ -239,8 +270,6 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
if enable_sequence_parallelism:
|
||||
assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism'
|
||||
|
||||
# TODO(ver217): support zero
|
||||
assert zero_stage == 0, 'zero is not support yet'
|
||||
self.tp_size = tp_size
|
||||
self.pp_size = pp_size
|
||||
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
|
||||
|
@ -282,11 +311,18 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
)
|
||||
|
||||
self.ddp_config = dict(broadcast_buffers=broadcast_buffers,
|
||||
bucket_cap_mb=bucket_cap_mb,
|
||||
bucket_cap_mb=ddp_bucket_cap_mb,
|
||||
find_unused_parameters=find_unused_parameters,
|
||||
check_reduction=check_reduction,
|
||||
gradient_as_bucket_view=gradient_as_bucket_view,
|
||||
static_graph=static_graph)
|
||||
|
||||
self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
|
||||
communication_dtype=communication_dtype,
|
||||
overlap_communication=overlap_communication,
|
||||
cpu_offload=cpu_offload,
|
||||
partition_grad=(self.zero_stage == 2))
|
||||
|
||||
self.max_norm = max_norm
|
||||
|
||||
@property
|
||||
|
@ -337,15 +373,16 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism)
|
||||
else:
|
||||
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
|
||||
assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
|
||||
optimizer = HybridParallelZeroOptimizer(optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
partition_grad=(self.zero_stage == 2),
|
||||
cpu_offload=self.cpu_offload,
|
||||
dp_process_group=self.dp_group,
|
||||
tp_process_group=self.tp_group,
|
||||
verbose=True,
|
||||
clip_grad_norm=self.max_norm,
|
||||
**self.zero_config,
|
||||
**self.amp_config)
|
||||
return model, optimizer, criterion, dataloader, lr_scheduler
|
||||
|
||||
|
|
|
@ -56,9 +56,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
atol, rtol = 1e-4, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
#check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3)
|
||||
#check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3)
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
|
||||
check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
|
||||
|
||||
|
@ -101,6 +99,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32'
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': True,
|
||||
'zero_stage': 2,
|
||||
'precision': 'fp16',
|
||||
'initial_scale': 1
|
||||
}])
|
||||
def run_bert_test(test_config):
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
# check grad
|
||||
row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
|
||||
col_layer_for_check = ['h[0].self_attention.dense']
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-6, 1e-5
|
||||
else:
|
||||
|
@ -101,6 +101,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32'
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': True,
|
||||
'zero_stage': 2,
|
||||
'precision': 'fp16',
|
||||
'initial_scale': 1
|
||||
}])
|
||||
def run_bloom_test(test_config):
|
||||
|
||||
|
|
|
@ -55,7 +55,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
# check grad
|
||||
row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
|
||||
col_layer_for_check = ['encoder.layers[0].self_attention.dense']
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-6, 1e-3
|
||||
else:
|
||||
|
@ -125,6 +125,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32'
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': True,
|
||||
'zero_stage': 2,
|
||||
'precision': 'fp16',
|
||||
'initial_scale': 1
|
||||
}])
|
||||
def run_chatglm_test(test_config):
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
|
||||
|
||||
# check grad
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-4, 1e-3
|
||||
else:
|
||||
|
@ -120,6 +120,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
'use_lazy_init': True,
|
||||
'enable_sequence_parallelism': True,
|
||||
'precision': 'fp32',
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': True,
|
||||
'zero_stage': 2,
|
||||
'precision': 'fp16',
|
||||
'initial_scale': 1
|
||||
}])
|
||||
@clear_cache_before_run()
|
||||
def run_gpt2_test(test_config):
|
||||
|
|
|
@ -60,7 +60,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
# check grad
|
||||
row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
|
||||
col_layer_for_check = ['layers[0].self_attn.o_proj']
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-6, 1e-4
|
||||
else:
|
||||
|
@ -135,6 +135,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32'
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': True,
|
||||
'zero_stage': 2,
|
||||
'precision': 'fp16',
|
||||
'initial_scale': 1
|
||||
}])
|
||||
def run_llama_test(test_config):
|
||||
|
||||
|
|
|
@ -58,7 +58,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
# check grad
|
||||
row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens'
|
||||
col_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-6, 1e-3
|
||||
else:
|
||||
|
@ -127,6 +127,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32'
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': True,
|
||||
'zero_stage': 2,
|
||||
'precision': 'fp16',
|
||||
'initial_scale': 1
|
||||
}])
|
||||
def run_opt_test(test_config):
|
||||
|
||||
|
|
|
@ -55,12 +55,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
|
||||
row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
|
||||
|
||||
# check weights and gradients
|
||||
# check grad
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-5, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
|
||||
|
||||
# check weights after optimizer.step()
|
||||
|
@ -110,6 +110,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32'
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': True,
|
||||
'zero_stage': 2,
|
||||
'precision': 'fp16',
|
||||
'initial_scale': 1
|
||||
}])
|
||||
@clear_cache_before_run()
|
||||
def run_t5_test(test_config):
|
||||
|
|
|
@ -55,7 +55,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
# check grad
|
||||
row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
|
||||
col_layer_for_check = ['encoder.layer[0].attention.output.dense']
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-5, 1e-3
|
||||
else:
|
||||
|
@ -124,6 +124,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32'
|
||||
}, {
|
||||
'tp_size': 2,
|
||||
'pp_size': 1,
|
||||
'enable_all_optimization': True,
|
||||
'use_lazy_init': False,
|
||||
'zero_stage': 2,
|
||||
'precision': 'fp16',
|
||||
'initial_scale': 1
|
||||
}])
|
||||
def run_vit_test(test_config):
|
||||
|
||||
|
|
Loading…
Reference in New Issue