diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 42942aaeb..28a19af0c 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -148,7 +148,10 @@ class HybridParallelPlugin(PipelinePluginBase):
                  precision: str = 'fp16',
                  zero_stage: int = 0,
                  cpu_offload: bool = False,
+                 enable_all_optimization: bool = False,
                  enable_fused_normalization: bool = False,
+                 enable_flash_attention: bool = False,
+                 enable_jit_fused: bool = False,
                  num_microbatches: Optional[int] = None,
                  initial_scale: float = 2**16,
                  min_scale: float = 1,
@@ -171,7 +174,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.precision = precision
         self.zero_stage = zero_stage
         self.cpu_offload = cpu_offload
+        self.enable_all_optimization = enable_all_optimization
         self.enable_fused_normalization = enable_fused_normalization
+        self.enable_flash_attention = enable_flash_attention
+        self.enable_jit_fused = enable_jit_fused
         self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
         self.stage_manager = None
         self.schedule = None
@@ -186,7 +192,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
                                         pipeline_stage_manager=self.stage_manager,
                                         enable_tensor_parallelism=self.tp_size > 1,
-                                        enable_fused_normalization=self.enable_fused_normalization)
+                                        enable_all_optimization=self.enable_all_optimization,
+                                        enable_fused_normalization=self.enable_fused_normalization,
+                                        enable_flash_attention=self.enable_flash_attention,
+                                        enable_jit_fused=self.enable_jit_fused)
         self.amp_config = dict(
             initial_scale=initial_scale,
             growth_factor=growth_factor,
diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index 510af5f3c..a37d00326 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -19,4 +19,4 @@ ninja
 flash_attn>=2.0
 datasets
 ninja
-flash-attn
+flash-attn>=2.0
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index 98cdc5a4b..cce21809d 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -1,6 +1,5 @@
 import copy
 from contextlib import nullcontext
-from typing import Optional
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
@@ -16,8 +15,8 @@ from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.lazy import LazyInitContext
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.shardformer.policies.auto_policy import Policy
 from colossalai.shardformer._utils import getattr_
+from colossalai.shardformer.policies.auto_policy import Policy
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
 
 
@@ -156,10 +155,12 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
     else:
         data = {k: v.cuda() for k, v in data.items()}
         sharded_output = sharded_model(**data)
+
     sharded_loss = criterion(sharded_output)
-    sharded_loss.backward()
+    sharded_optimizer.backward(sharded_loss)
 
     org_model.train()
+    data = {k: v.cuda() for k, v in data.items()}
     org_output = org_model(**data)
     org_loss = criterion(org_output)
     org_loss.backward()
@@ -181,12 +182,12 @@ def check_output_hidden_state(org_output: Tensor,
 
     if stage_manager and stage_manager.is_last_stage():
         sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0)
 
-    assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=atol, rtol=rtol), \
+    assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
         f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
 
 
 def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
-    assert torch.allclose(org_loss, sharded_loss, atol=atol, rtol=rtol), \
+    assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol), \
         f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"
 
@@ -213,7 +214,7 @@ def check_weight(org_model: Module,
         if verbose and dist.get_rank() == 0:
             print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")
 
-        assert torch.allclose(org_weight, sharded_weight, atol=atol, rtol=rtol), \
+        assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \
             f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
 
 
@@ -244,6 +245,7 @@ def check_grad(org_model: Module,
         if verbose and dist.get_rank() == 0:
             print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
+
         assert torch.allclose(
-            org_grad, shard_grad, rtol=rtol, atol=atol
+            org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol
         ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
 
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 1882bf782..3ac8fa26d 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -3,6 +3,7 @@ import torch
 from torch import distributed as dist
 
 import colossalai
+from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@@ -38,33 +39,49 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 
     # check last hidden state & loss
     if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
         if org_model.__class__.__name__ == 'GPT2Model':
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
 
-        check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
+        # check loss
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    def unwrap(module):
+        if isinstance(module, HybridParallelModule):
+            module = module.unwrap()
+        if module.__class__.__name__ == 'GPT2Model':
+            return module
+        return module.transformer
 
     # unwrap model
-    if org_model.__class__.__name__ == 'GPT2Model':
-        gpt2 = org_model
-        sharded_gpt2 = sharded_model.unwrap()
-    else:
-        gpt2 = org_model.transformer
-        sharded_gpt2 = sharded_model.unwrap().transformer
+    gpt2 = unwrap(org_model)
+    sharded_gpt2 = unwrap(sharded_model)
 
     col_layer_for_check = ['h[0].mlp.c_fc']
     row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
 
     # check grad
+    if test_config['precision'] == 'fp32':
+        atol, rtol = 1e-4, 1e-3
+    else:
+        atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
-        check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
-        check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
+        check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
+        check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
 
     # check weights after optimizer.step()
     org_optimizer.step()
     sharded_optimizer.step()
+    if test_config['precision'] == 'fp32':
+        atol, rtol = 5e-3, 1e-3
+    else:
+        atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
-        check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False)
+        check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
 
     torch.cuda.empty_cache()
 
 
@@ -73,29 +90,31 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'tp_size': 2,
     'pp_size': 2,
     'num_microbatches': 4,
-    'enable_fused_normalization': True,
-    'use_lazy_init': True
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'precision': 'fp32',
 }, {
     'tp_size': 1,
     'pp_size': 2,
     'num_microbatches': 4,
-    'use_lazy_init': False
+    'enable_all_optimization': True,
+    'use_lazy_init': False,
+    'precision': 'fp16',
+    'initial_scale': 1,
 }, {
     'tp_size': 4,
     'pp_size': 1,
-    'enable_fused_normalization': True,
-    'use_lazy_init': False
+    'enable_all_optimization': True,
+    'use_lazy_init': False,
+    'precision': 'fp32',
 }])
 @clear_cache_before_run()
 def run_gpt2_test(test_config):
 
     # TODO: add test_config for TP+DP after supporting & debugging it
-    # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
-
-    # TODO: add test_config for flash attention & jit operator after supporting
+    # TODO: check and debug TP+AMP
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
-    test_config['precision'] = 'float'    # Do not use fp16/bf16 in testing
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
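
Usage sketch (not part of the patch): a minimal, hypothetical example of how the new HybridParallelPlugin switches added above might be driven through the Booster API. The model, optimizer and criterion objects are assumed to exist elsewhere, and the comment about enable_all_optimization turning on the individual fused/flash/jit switches is an assumption inferred from the ShardConfig wiring in this diff, not a confirmed behavior.

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    # assumes the distributed environment is started via torchrun / colossalai run
    colossalai.launch_from_torch(config={})

    plugin = HybridParallelPlugin(
        tp_size=2,
        pp_size=2,
        num_microbatches=4,
        precision='fp16',
        initial_scale=1,
        # new switch from this diff; presumably enables fused normalization,
        # flash attention and jit fusion together (they can also be toggled
        # individually via enable_flash_attention / enable_jit_fused /
        # enable_fused_normalization)
        enable_all_optimization=True,
    )
    booster = Booster(plugin=plugin)

    # model, optimizer, criterion are assumed to be constructed elsewhere
    # model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)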