[pipeline] rewrite bert tests and fix some bugs (#4409)

* add pipeline policy and bert forward to be done

* add bertmodel pipeline forward and make tests

* add Bert_Policy and test for policy

* update formatting

* update formatting

* update the code

* fix bugs

* fix name confilt

* add bloom model and policy ,revise the base class of policy

* revise

* revision

* add bert_for_pretraining

* add bert_for_pretraining forward and policy

* fix typos

* cancel warning

* change the imediate output to default dict

* change the default output of get_shared_params

* rewrite bert test

* rewrite bert test

* fix some bugs

* del pipeline tests

* del pipeline tests

* del useless print

* del useless print

* rewrite data repeats
pull/4445/head
Jianghai 2023-08-11 10:32:53 +08:00 committed by Hongxin Liu
parent d2cd48e0be
commit 7596e9ae08
4 changed files with 83 additions and 154 deletions

View File

@ -104,7 +104,8 @@ def data_gen_for_qa():
output_transform_fn = lambda x: x
# define loss funciton
loss_fn_for_bert_model = lambda x: x.pooler_output.sum()
loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state
))
loss_fn = lambda x: x.loss
config = transformers.BertConfig(hidden_size=128,

View File

@ -131,6 +131,8 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Module, sharded_optimizer: Optimizer,
data_gen_fn: Callable, output_transform_fn: Callable, criterion: Callable,
booster: Booster):
org_model.cuda()
sharded_model.cuda()
def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
@ -141,7 +143,8 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
sharded_model.train()
if booster.plugin.stage_manager is not None:
data = {
k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
k: v.to('cuda').repeat(*([4] + [1] *
(v.dim() - 1))) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
for k, v in data.items()
}
data_iter = iter([data])
@ -162,6 +165,7 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
org_model.train()
data = {k: v.cuda() for k, v in data.items()}
org_output = org_model(**data)
org_loss = criterion(org_output)
org_loss.backward()
@ -226,7 +230,6 @@ def check_grad(org_model: Module,
atol: float = 1e-5,
rtol: float = 1e-3,
verbose: bool = False):
for suffix in layer_suffix:
org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
@ -242,7 +245,6 @@ def check_grad(org_model: Module,
# embedding may be resized when using tensor parallel
if shard_grad.shape[0] > org_grad.shape[0]:
shard_grad = shard_grad[:org_grad.shape[0], :]
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")

View File

@ -1,65 +1,98 @@
import pytest
import torch
from torch import distributed as dist
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.auto_policy import get_autopolicy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_loss,
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
)
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# unwarp model
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if org_model.__class__.__name__ == 'BertModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
# unwrap model
if org_model.__class__.__name__ == 'BertModel':
bert = org_model
sharded_bert = sharded_model
sharded_bert = sharded_model.unwrap()
else:
bert = org_model.bert
sharded_bert = sharded_model.bert
sharded_bert = sharded_model.unwrap().bert
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output)
col_layer_for_check = ['encoder.layer[0].output.dense']
row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']
# do backward
org_loss.backward()
shard_loss.backward()
if stage_manager is None or stage_manager.is_first_stage():
#check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3)
#check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3)
check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False)
# check grad
col_layer_for_check = ['encoder.layer[0].attention.self.query', 'embeddings.word_embeddings']
row_layer_for_check = ['encoder.layer[0].attention.output.dense']
check_grad(bert, sharded_bert, col_layer_for_check, atol=1e-7, rtol=1e-3, dim=0, verbose=False)
check_grad(bert, sharded_bert, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False)
torch.cuda.empty_cache()
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
@parameterize('use_lazy_init', [False, True])
def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
use_lazy_init):
@parameterize('test_config', [{
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'use_lazy_init': True
}, {
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': False,
'use_lazy_init': False
}, {
'tp_size': 4,
'pp_size': 1,
'enable_fused_normalization': True,
'use_lazy_init': False
}])
def run_bert_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
enable_flash_attention, enable_jit_fused, use_lazy_init)
check_state_dict(org_model, sharded_model, name=name)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
test_config['precision'] = 'float'
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@ -73,7 +106,7 @@ def check_bert(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bert():
spawn(check_bert, 2)
spawn(check_bert, 4)
if __name__ == "__main__":

View File

@ -1,107 +0,0 @@
import pytest
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.auto_policy import get_autopolicy
from colossalai.shardformer.shard import ShardConfig
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
def check_bert_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
stage_manager = stage_manager
policy = get_autopolicy(model)
policy.set_model(model)
model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
policy.set_shard_config(model_config)
layers = policy.get_held_layers()
if stage_manager.is_first_stage():
assert len(layers) == 1 + 1
else:
if name == "transformers_bert":
assert len(layers) == 1 + 1
elif name in [
"transformers_bert_for_sequence_classification", "transformers_bert_for_token_classification",
"transformers_bert_for_mcq"
]:
assert len(layers) == 1 + 3
else:
assert len(layers) == 1 + 2
def check_bert_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
if name == 'transformers_bert_for_mcq':
x = torch.randint(0, 1000, (2, 3, 3)).cuda()
attention_mask = torch.ones_like(x).cuda()
if stage_manager.stage == 0:
output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
assert output['hidden_states'].shape == (6, 3, 128)
else:
hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda()
output = sharded_model(input_ids=x,
hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager)
assert output[0].shape == (2, 3)
else:
x = torch.randint(0, 1000, (2, 3)).cuda()
# one batch, 2 single sentences, each sentence has 3 tokens
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x).cuda()
output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
assert output['hidden_states'].shape == (2, 3, 128)
else:
attention_mask = torch.ones((2, 3)).cuda()
output = sharded_model(hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager)
assert output[0].shape[0] == 2
@parameterize('enable_fused_normalization', [False])
@parameterize('enable_tensor_parallelism', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_bert
def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
PP_DIM = 0
PP_SIZE = 2
pg_mesh = ProcessGroupMesh(PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
check_bert_model_policy(name, org_model, stage_manager)
check_bert_model_pipeline_forward(name, sharded_model, stage_manager)
torch.cuda.empty_cache()
def check_bert(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_bert_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bert():
spawn(check_bert, 2)
if __name__ == "__main__":
test_bert()