From 0ceec8f9a9401b6ed10c916fcf8bf9c60fceefd9 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 1 Aug 2023 17:29:09 +0800 Subject: [PATCH] [pipeline] support fp32 for HybridPlugin/merge shardformer test and pipeline test into one file (#4354) * add naive optimizer for 3DPlugin/refactor gpt2 shardformer test * merge tests of PP/DP/TP combinations into one test file * fix bug when sync grad for dp in HybridPlugin * update supported precisions for 3DPlugin/fix bug when shifting tp_degree * improve the passing of lazy_init * modify lazy_init/use sync_shared_params --- .../naive_amp/mixed_precision_optimizer.py | 2 +- .../booster/plugin/hybrid_parallel_plugin.py | 37 +++- .../shardformer/layer/qkv_fused_linear.py | 4 +- colossalai/tensor/d_tensor/api.py | 5 + tests/kit/model_zoo/transformers/gpt.py | 3 +- .../test_model/test_pure_pipeline.py | 1 - .../test_model/test_shard_gpt2.py | 205 +++++++++++++----- .../test_model/test_shard_gpt2_pipeline.py | 72 ------ 8 files changed, 187 insertions(+), 142 deletions(-) delete mode 100644 tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py index d4183be3f..626a00c96 100644 --- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -134,7 +134,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper): working_param = self.master_to_working_map[p] if p is working_param: continue - if working_param.grad is None: + if working_param.grad is not None: p.grad = working_param.grad.data.float() working_param.grad = None total_norm = self._compute_grad_norm() diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 37badb613..35a88d1e8 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -42,6 +42,8 @@ class HybridParallelModule(ModelWrapper): module = module.half().cuda() elif precision == 'bf16': module = module.to(dtype=torch.bfloat16).cuda() + else: + module = module.cuda() # train without AMP # TODO(ver217): support TP+DP super().__init__(module) @@ -61,6 +63,7 @@ class HybridParallelModule(ModelWrapper): for p in self.module.parameters(): if p.grad is not None: dist.all_reduce(p.grad, group=self.dp_group) + p.grad.div_(self.dp_group.size()) def init_pipeline_optimizer(optim: Optimizer, model: Module): @@ -72,7 +75,15 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module): optim.__setstate__({'param_groups': new_param_groups}) -class HybridParallelOptimizer(MixedPrecisionOptimizer): +class HybridParallelNaiveOptimizer(OptimizerWrapper): + + def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool): + if use_pipeline: + init_pipeline_optimizer(optim, model) + super().__init__(optim) + + +class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): def __init__(self, optim: Optimizer, @@ -192,7 +203,7 @@ class HybridParallelPlugin(PipelinePluginBase): return ['cuda'] def supported_precisions(self) -> List[str]: - return ['fp16', 'bf16'] + return ['fp16', 'bf16', 'fp32'] def control_device(self) -> bool: return True @@ -218,12 +229,17 @@ class HybridParallelPlugin(PipelinePluginBase): model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: - optimizer = 
HybridParallelOptimizer(optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - precision=self.precision, - max_norm=self.max_norm, - **self.amp_config) + if self.precision in ['fp16', 'bf16']: + optimizer = HybridParallelAMPOptimizer(optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + precision=self.precision, + max_norm=self.max_norm, + **self.amp_config) + else: + optimizer = HybridParallelNaiveOptimizer(optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism) else: optimizer = HybridParallelZeroOptimizer(optimizer, model, @@ -241,7 +257,8 @@ class HybridParallelPlugin(PipelinePluginBase): data_iter: Iterator, model: HybridParallelModule, criterion: Callable[[Any, Any], torch.Tensor], - optimizer: Union[HybridParallelOptimizer, HybridParallelZeroOptimizer], + optimizer: Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, + HybridParallelZeroOptimizer], return_loss: bool = True, return_outputs: bool = False) -> dict: assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled' @@ -250,7 +267,7 @@ class HybridParallelPlugin(PipelinePluginBase): with ctx: outputs = self.schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss, return_outputs) - # model.sync_shared_params() + model.sync_shared_params() if isinstance(optimizer, HybridParallelZeroOptimizer): optimizer.sync_grad() else: diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index bcefcf058..3c47c0b11 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -456,12 +456,12 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): if self.parallel_input: assert input_.shape[-1] == self.weight.shape[0], \ 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) + input_.shape, self.weight.shape, self.weight.shape[0]) input_ = input_ else: assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0], \ 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) + input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions) input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) if self.stream_chunk_num > 1: diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index 32182faf6..9848e4ca4 100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -16,6 +16,11 @@ from .sharding_spec import ShardingSpec layout_converter = LayoutConverter() +def clear_layout_converter(): + global layout_converter + layout_converter.cached_solution.clear() + + def is_distributed_tensor(tensor: torch.Tensor) -> bool: """ Check whether the given tensor is a distributed tensor. 
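
The sync_grads change above turns the data-parallel reduction into a true average: each gradient is SUM-all-reduced over the DP group and then divided by the group size. A minimal standalone sketch of the same pattern follows; the function and variable names here are illustrative and not part of this patch.

import torch.distributed as dist
from torch import nn


def sync_dp_grads(model: nn.Module, dp_group: dist.ProcessGroup) -> None:
    # Average gradients across the data-parallel group: SUM all-reduce,
    # then divide by the group size, mirroring the behaviour added to
    # HybridParallelModule.sync_grads in this patch.
    world_size = dist.get_world_size(group=dp_group)
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=dp_group)
            p.grad.div_(world_size)
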
diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index e447b7001..fcde75abd 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -70,7 +70,8 @@ config = transformers.GPT2Config(n_layer=2, resid_pdrop=0, summary_first_dropout=0, hidden_dropout=0, - problem_type="single_label_classification") + problem_type="single_label_classification", + pad_token_id=50256) # register the following models model_zoo.register(name='transformers_gpt', diff --git a/tests/test_shardformer/test_model/test_pure_pipeline.py b/tests/test_shardformer/test_model/test_pure_pipeline.py index 576e6473b..31e76ef51 100644 --- a/tests/test_shardformer/test_model/test_pure_pipeline.py +++ b/tests/test_shardformer/test_model/test_pure_pipeline.py @@ -160,7 +160,6 @@ def check_llama(rank, world_size, port): run_llama_test() -@pytest.mark.skip('This test will fail') @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 99451b403..eae4f2ffb 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -1,85 +1,180 @@ +import copy +from contextlib import nullcontext + import pytest import torch +from torch import distributed as dist +from torch.optim import Adam import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.lazy.lazy_init import LazyInitContext from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, +from colossalai.tensor.d_tensor.api import ( + clear_layout_converter, + is_customized_distributed_tensor, + is_distributed_tensor, ) +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # check forward - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): - # do backward + use_lazy_init = False + if 'use_lazy_init' in test_config: + use_lazy_init = test_config.pop('use_lazy_init') + + if use_lazy_init: + ctx = LazyInitContext() + else: + ctx = nullcontext() + + # prepare booster + plugin = HybridParallelPlugin(**test_config) + booster = Booster(plugin=plugin) + stage_manager = plugin.stage_manager + + # prepare models and optimizers + with ctx: + org_model = model_fn().cuda() + sharded_model = copy.deepcopy(org_model) + + if use_lazy_init: + org_model = ctx.materialize(org_model) + + org_optimizer = Adam(org_model.parameters(), lr=1e-3) + sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3) + criterion = loss_fn + + sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion) + + def _criterion(outputs, 
inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + # do forward and backward + data = data_gen_fn() + sharded_model.train() + if stage_manager: + data = { + k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v + for k, v in data.items() + } + data_iter = iter([data]) + sharded_output = booster.execute_pipeline(data_iter, + sharded_model, + _criterion, + sharded_optimizer, + return_loss=True, + return_outputs=True) + sharded_loss = sharded_output['loss'] + else: + data = {k: v.cuda() for k, v in data.items()} + sharded_output = sharded_model(**data) + sharded_loss = criterion(sharded_output) + sharded_loss.backward() + + org_model.train() + org_output = org_model(**data) + org_loss = criterion(org_output) org_loss.backward() - shard_loss.backward() - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}" + if stage_manager is None or stage_manager.is_last_stage(): + + # check last hidden state + if org_model.__class__.__name__ == 'GPT2Model': + org_hidden_state = org_output.last_hidden_state + + if stage_manager is None: + sharded_hidden_state = sharded_output.last_hidden_state + + if stage_manager and stage_manager.is_last_stage(): + sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], + dim=0) + + assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=1e-5, rtol=1e-3), \ + f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" + + # check loss + assert torch.allclose(org_loss, sharded_loss, atol=1e-5, rtol=1e-3), \ + f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}" # unwrap model if org_model.__class__.__name__ == 'GPT2Model': org_model = org_model - sharded_model = sharded_model + sharded_model = sharded_model.unwrap() else: org_model = org_model.transformer - sharded_model = sharded_model.transformer + sharded_model = sharded_model.unwrap().transformer - # check mlp grad - org_grad = org_model.h[0].mlp.c_fc.weight.grad - shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad - shard_weight = sharded_model.h[0].mlp.c_fc.weight + # check weights and gradients + if stage_manager is None or stage_manager.is_first_stage(): - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=1) - else: - all_shard_grad = shard_grad - assert torch.allclose( - org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}" + shard_weight = sharded_model.h[0].mlp.c_fc.weight + org_grad = org_model.h[0].mlp.c_fc.weight.grad + shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad - # check embedding weights - org_grad = org_model.wte.weight.grad - shard_grad = sharded_model.wte.weight.grad - shard_weight = sharded_model.wte.weight + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(plugin.tp_size)] + dist.all_gather(shard_grad_list, shard_grad, plugin.tp_group) + shard_grad = torch.cat(shard_grad_list, dim=1) + + assert torch.allclose(org_grad, shard_grad, atol=1e-5, 
rtol=1e-3), \ + f"shard model grad is not equal to origin model grad\n{org_grad}\n{shard_grad}" + + # check weights after optimizer.step() + org_optimizer.step() + sharded_optimizer.step() + if stage_manager is None or stage_manager.is_first_stage(): + + org_weight = org_model.h[0].mlp.c_fc.weight + shard_weight = sharded_model.h[0].mlp.c_fc.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_weight_list = [torch.zeros([*shard_weight.shape]).to('cuda') for _ in range(plugin.tp_size)] + dist.all_gather(shard_weight_list, shard_weight, plugin.tp_group) + shard_weight = torch.cat(shard_weight_list, dim=1) + + assert torch.allclose(org_weight, shard_weight, atol=5e-3, rtol=1e-3), \ + f"shard model weight is not equal to origin model weight\n{org_weight}\n{shard_weight}" - if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) - else: - all_shard_grad = shard_grad - assert torch.allclose( - org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}" torch.cuda.empty_cache() -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('use_lazy_init', [False, True]) +@parameterize('test_config', [{ + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'use_lazy_init': True +}, { + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'enable_fused_normalization': False, + 'use_lazy_init': False +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_fused_normalization': True, + 'use_lazy_init': False +}]) @clear_cache_before_run() -def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - use_lazy_init) - check_state_dict(org_model, sharded_model, name=name) - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) +def run_gpt2_test(test_config): + # TODO: add plugin_config for TP+DP after supporting & debugging it + # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True} + + sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') + test_config['precision'] = 'float' # Do not use fp16/bf16 in testing + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() torch.cuda.empty_cache() @@ -93,7 +188,7 @@ def check_gpt2(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_gpt2(): - spawn(check_gpt2, 2) + spawn(check_gpt2, 4) if __name__ == "__main__": diff --git a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py b/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py deleted file mode 100644 index d5453ee72..000000000 --- a/tests/test_shardformer/test_model/test_shard_gpt2_pipeline.py +++ /dev/null @@ -1,72 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.cluster import 
ProcessGroupMesh -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_pipeline_model - - -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # TODO: add tests for forward/backward later - pass - - -@parameterize('enable_tensor_parallelism', [False]) -@parameterize('enable_fused_normalization', [False]) -@parameterize('use_lazy_init', [False]) -#TODO: merge this into test_shard_gpt2 -def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - - sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') - for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): - inputs = data_gen_fn() - inputs = {k: v.cuda() for k, v in inputs.items()} - _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, - enable_tensor_parallelism, use_lazy_init) - input_ids = inputs['input_ids'] - batch_size, seq_len = input_ids.shape - hidden_size = sharded_model.config.n_embd - hidden_state_shape = (batch_size, seq_len, hidden_size) - - if not stage_manager.is_first_stage(): - # change inputs if not the first stage - hidden_states = torch.zeros(*hidden_state_shape).cuda() - inputs['input_ids'] = None - inputs['hidden_states'] = hidden_states - - sharded_model.train() - output = sharded_model(**inputs) - if stage_manager.is_last_stage(): - if name == 'transformers_gpt': - assert output[0].shape == hidden_state_shape - else: - assert output.loss is not None - else: - assert output['hidden_states'].shape == hidden_state_shape - - torch.cuda.empty_cache() - - -def check_gpt2(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_gpt2_test() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -@clear_cache_before_run() -def test_gpt2(): - spawn(check_gpt2, 4) - - -if __name__ == "__main__": - test_gpt2()
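
For context on how the merged test exercises the newly supported full-precision path: constructing HybridParallelPlugin without mixed precision makes booster.boost() wrap the optimizer in HybridParallelNaiveOptimizer instead of the AMP wrapper. A hedged usage sketch under assumed names is shown below; model_fn and loss_fn stand in for the model-zoo builders used in the test, and the plugin kwargs mirror the parameterized test configs above.

from torch.optim import Adam

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin


def boost_gpt2_fp32(model_fn, loss_fn):
    # Assumes a 4-process job already initialised, e.g. via
    # colossalai.launch(config={}, rank=rank, world_size=4, host='localhost',
    #                   port=port, backend='nccl') as in check_gpt2().
    plugin = HybridParallelPlugin(tp_size=2,
                                  pp_size=2,
                                  num_microbatches=4,
                                  precision='fp32')  # 'fp32' is now advertised; the merged test passes 'float' for the same path
    booster = Booster(plugin=plugin)

    model = model_fn().cuda()
    optimizer = Adam(model.parameters(), lr=1e-3)

    # With full precision the plugin skips HybridParallelAMPOptimizer and
    # returns a HybridParallelNaiveOptimizer wrapper instead.
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, loss_fn)
    return booster, model, optimizer, criterion
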