[shardformer] test all optimizations (#4399)

[shardformer] test all optimizations

[shardformer] test all optimizations

[shardformer] test all optimizations
Branch: pull/4445/head
Author: flybird1111
Date: 2023-08-10 13:59:30 +08:00
Committed by: Hongxin Liu
Commit: d2cd48e0be
Parent: 7a3dfd0c64

4 changed files with 59 additions and 29 deletions


@@ -148,7 +148,10 @@ class HybridParallelPlugin(PipelinePluginBase):
                  precision: str = 'fp16',
                  zero_stage: int = 0,
                  cpu_offload: bool = False,
+                 enable_all_optimization: bool = False,
                  enable_fused_normalization: bool = False,
+                 enable_flash_attention: bool = False,
+                 enable_jit_fused: bool = False,
                  num_microbatches: Optional[int] = None,
                  initial_scale: float = 2**16,
                  min_scale: float = 1,
@@ -171,7 +174,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.precision = precision
         self.zero_stage = zero_stage
         self.cpu_offload = cpu_offload
+        self.enable_all_optimization = enable_all_optimization
         self.enable_fused_normalization = enable_fused_normalization
+        self.enable_flash_attention = enable_flash_attention
+        self.enable_jit_fused = enable_jit_fused
         self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
         self.stage_manager = None
         self.schedule = None
@@ -186,7 +192,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
                                         pipeline_stage_manager=self.stage_manager,
                                         enable_tensor_parallelism=self.tp_size > 1,
-                                        enable_fused_normalization=self.enable_fused_normalization)
+                                        enable_all_optimization=self.enable_all_optimization,
+                                        enable_fused_normalization=self.enable_fused_normalization,
+                                        enable_flash_attention=self.enable_flash_attention,
+                                        enable_jit_fused=self.enable_jit_fused)
         self.amp_config = dict(
             initial_scale=initial_scale,
             growth_factor=growth_factor,

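Note: the three switches added to HybridParallelPlugin above are forwarded straight into ShardConfig, where enable_all_optimization acts as a shorthand that turns on the individual shardformer optimizations (fused normalization, flash attention, JIT-fused operators), while the per-feature flags remain available for finer control. A minimal, hedged usage sketch follows; the Booster setup, parallel sizes and precision are illustrative and not part of this commit.

# Illustrative sketch of the new plugin arguments; assumes the process group
# has been launched via torchrun so colossalai.launch_from_torch can pick it up.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    precision='fp16',
    enable_all_optimization=True,      # shorthand for the individual switches below
    # enable_fused_normalization=True,
    # enable_flash_attention=True,
    # enable_jit_fused=True,
)
booster = Booster(plugin=plugin)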

@@ -19,4 +19,4 @@ ninja
 flash_attn>=2.0
 datasets
 ninja
-flash-attn
+flash-attn>=2.0

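The tightened flash-attn pin above matters because shardformer's flash-attention path targets the flash-attn 2.x API, which differs from 1.x. A hedged sketch for sanity-checking a test environment (not part of the commit):

# Illustrative check that the installed flash-attn satisfies the new pin.
from packaging.version import Version

import flash_attn

assert Version(flash_attn.__version__) >= Version("2.0"), \
    f"found flash-attn {flash_attn.__version__}, but the tests expect flash-attn>=2.0"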

@@ -1,6 +1,5 @@
 import copy
 from contextlib import nullcontext
-from typing import Optional
 from typing import Any, Callable, Dict, List, Optional

 import torch
@@ -16,8 +15,8 @@ from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.lazy import LazyInitContext
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.shardformer.policies.auto_policy import Policy
 from colossalai.shardformer._utils import getattr_
+from colossalai.shardformer.policies.auto_policy import Policy
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
@@ -156,10 +155,12 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
     else:
         data = {k: v.cuda() for k, v in data.items()}
         sharded_output = sharded_model(**data)
         sharded_loss = criterion(sharded_output)
-        sharded_loss.backward()
+        sharded_optimizer.backward(sharded_loss)

     org_model.train()
+    data = {k: v.cuda() for k, v in data.items()}
     org_output = org_model(**data)
     org_loss = criterion(org_output)
     org_loss.backward()
@@ -181,12 +182,12 @@ def check_output_hidden_state(org_output: Tensor,
     if stage_manager and stage_manager.is_last_stage():
         sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0)

-    assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=atol, rtol=rtol), \
+    assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
         f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"


 def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
-    assert torch.allclose(org_loss, sharded_loss, atol=atol, rtol=rtol), \
+    assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol), \
         f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"
@@ -213,7 +214,7 @@ def check_weight(org_model: Module,
         if verbose and dist.get_rank() == 0:
             print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")
-        assert torch.allclose(org_weight, sharded_weight, atol=atol, rtol=rtol), \
+        assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \
             f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
@@ -244,6 +245,7 @@ def check_grad(org_model: Module,
         if verbose and dist.get_rank() == 0:
             print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
+
         assert torch.allclose(
-            org_grad, shard_grad, rtol=rtol, atol=atol
+            org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol
         ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"

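Two behavioral notes on the test-utility changes above. First, the sharded backward now goes through sharded_optimizer.backward(sharded_loss) rather than sharded_loss.backward(), so that the optimizer wrapper, rather than the raw tensor, drives the backward pass and can apply loss scaling when mixed precision or ZeRO is enabled. Second, the .float() casts make every torch.allclose comparison dtype-agnostic, since the sharded side may now produce fp16 tensors while the reference model stays in fp32. A hedged, standalone illustration of the second point:

# Illustrative only: upcasting both sides keeps the comparison meaningful
# when one model runs in fp16 and the other in fp32.
import torch

org_loss = torch.tensor(1.2345678, dtype=torch.float32)   # reference model, fp32
sharded_loss = org_loss.to(torch.float16)                  # sharded model under AMP

assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=5e-3, rtol=5e-3)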

@@ -3,6 +3,7 @@ import torch
 from torch import distributed as dist

 import colossalai
+from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@@ -38,33 +39,49 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check last hidden state & loss
     if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3
         if org_model.__class__.__name__ == 'GPT2Model':
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

-        check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
+        # check loss
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

+    def unwrap(module):
+        if isinstance(module, HybridParallelModule):
+            module = module.unwrap()
+        if module.__class__.__name__ == 'GPT2Model':
+            return module
+        return module.transformer
+
     # unwrap model
-    if org_model.__class__.__name__ == 'GPT2Model':
-        gpt2 = org_model
-        sharded_gpt2 = sharded_model.unwrap()
-    else:
-        gpt2 = org_model.transformer
-        sharded_gpt2 = sharded_model.unwrap().transformer
+    gpt2 = unwrap(org_model)
+    sharded_gpt2 = unwrap(sharded_model)

     col_layer_for_check = ['h[0].mlp.c_fc']
     row_layer_for_check = ['wte', 'h[0].mlp.c_proj']

     # check grad
+    if test_config['precision'] == 'fp32':
+        atol, rtol = 1e-4, 1e-3
+    else:
+        atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
-        check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
-        check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
+        check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
+        check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)

     # check weights after optimizer.step()
     org_optimizer.step()
     sharded_optimizer.step()
+    if test_config['precision'] == 'fp32':
+        atol, rtol = 5e-3, 1e-3
+    else:
+        atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
-        check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False)
+        check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)

     torch.cuda.empty_cache()
@@ -73,29 +90,31 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'tp_size': 2,
     'pp_size': 2,
     'num_microbatches': 4,
-    'enable_fused_normalization': True,
-    'use_lazy_init': True
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'precision': 'fp32',
 }, {
     'tp_size': 1,
     'pp_size': 2,
     'num_microbatches': 4,
-    'use_lazy_init': False
+    'enable_all_optimization': True,
+    'use_lazy_init': False,
+    'precision': 'fp16',
+    'initial_scale': 1,
 }, {
     'tp_size': 4,
     'pp_size': 1,
-    'enable_fused_normalization': True,
-    'use_lazy_init': False
+    'enable_all_optimization': True,
+    'use_lazy_init': False,
+    'precision': 'fp32',
 }])
 @clear_cache_before_run()
 def run_gpt2_test(test_config):

     # TODO: add test_config for TP+DP after supporting & debugging it
-    # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
-
-    # TODO: add test_config for flash attention & jit operator after supporting
+    # TODO: check and debug TP+AMP

     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
-    test_config['precision'] = 'float'    # Do not use fp16/bf16 in testing

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
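Taken together, the GPT-2 test changes above exercise fused normalization, flash attention and JIT fusion through enable_all_optimization under both fp32 and fp16 (with initial_scale set to 1 for the fp16 run), and each check now picks its tolerances from the precision in use. A compact summary of the tolerance values introduced in the diff; the helper name is invented for illustration and does not exist in the codebase.

# Hypothetical helper mirroring the per-precision tolerances chosen in the test.
def pick_tolerances(precision: str, check: str):
    """Return (atol, rtol) for a given precision and check."""
    if precision == 'fp32':
        return {
            'output_and_loss': (1e-5, 1e-3),
            'grad': (1e-4, 1e-3),
            'weight_after_step': (5e-3, 1e-3),
        }[check]
    # fp16 runs share one looser bound for every check
    return (5e-3, 5e-3)

assert pick_tolerances('fp32', 'grad') == (1e-4, 1e-3)
assert pick_tolerances('fp16', 'weight_after_step') == (5e-3, 5e-3)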