mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] rewrite bert tests and fix some bugs (#4409)
* add pipeline policy and bert forward to be done
* add bertmodel pipeline forward and make tests
* add Bert_Policy and test for policy
* update formatting
* update the code
* fix bugs
* fix name conflict
* add bloom model and policy, revise the base class of policy
* revise
* revision
* add bert_for_pretraining
* add bert_for_pretraining forward and policy
* fix typos
* cancel warning
* change the immediate output to default dict
* change the default output of get_shared_params
* rewrite bert test
* fix some bugs
* del pipeline tests
* del useless print
* rewrite data repeats
parent d2cd48e0be
commit 7596e9ae08
tests/kit/model_zoo/transformers/bert.py

@@ -104,7 +104,8 @@ def data_gen_for_qa():
 output_transform_fn = lambda x: x

 # define loss funciton
-loss_fn_for_bert_model = lambda x: x.pooler_output.sum()
+loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state,
+                                                                torch.ones_like(x.last_hidden_state))
 loss_fn = lambda x: x.loss

 config = transformers.BertConfig(hidden_size=128,
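The new loss for the plain BertModel exercises every token of the last hidden state instead of summing the pooled [CLS] vector. A minimal standalone sketch of the new loss in isolation; the BertConfig arguments beyond hidden_size=128 are illustrative assumptions, since the config call is truncated in the hunk above:

import torch
import transformers

# Assumed config values: the diff only shows hidden_size=128; the rest are
# picked so the model constructs cleanly (hidden_size must be divisible by
# num_attention_heads).
config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2,
                                 num_attention_heads=4, intermediate_size=256)
model = transformers.BertModel(config)

out = model(input_ids=torch.randint(0, 1000, (2, 3)))
# The new loss: MSE of the last hidden state against an all-ones target.
loss = torch.nn.functional.mse_loss(out.last_hidden_state,
                                    torch.ones_like(out.last_hidden_state))
loss.backward()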
tests/test_shardformer/test_model/_utils.py

@@ -131,6 +131,8 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
 def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Module, sharded_optimizer: Optimizer,
                                             data_gen_fn: Callable, output_transform_fn: Callable, criterion: Callable,
                                             booster: Booster):
+    org_model.cuda()
+    sharded_model.cuda()

     def _criterion(outputs, inputs):
         outputs = output_transform_fn(outputs)
@@ -141,7 +143,8 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
     sharded_model.train()
     if booster.plugin.stage_manager is not None:
         data = {
-            k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
+            k: v.to('cuda').repeat(*([4] + [1] *
+                                     (v.dim() - 1))) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
             for k, v in data.items()
         }
         data_iter = iter([data])
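This is the "rewrite data repeats" fix from the commit message: the old repeat(4, 1) only worked for 2-D tensors, while the new expression builds a repeat pattern of (4, 1, 1, ...) matching the tensor's rank, so 3-D inputs such as multiple-choice batches also get only their batch dimension repeated. A standalone sketch of the same expression:

import torch

def repeat_batch(v: torch.Tensor, times: int = 4) -> torch.Tensor:
    # Same expression as in the diff: repeat the first (batch) dimension
    # `times` times and leave every other dimension untouched.
    return v.repeat(*([times] + [1] * (v.dim() - 1)))

x2d = torch.ones(2, 3)       # (batch, seq_len)
x3d = torch.ones(2, 3, 3)    # e.g. multiple-choice: (batch, num_choices, seq_len)
assert repeat_batch(x2d).shape == (8, 3)
assert repeat_batch(x3d).shape == (8, 3, 3)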
@@ -162,6 +165,7 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
     org_model.train()
+    data = {k: v.cuda() for k, v in data.items()}
     org_output = org_model(**data)

     org_loss = criterion(org_output)
     org_loss.backward()
@@ -226,7 +230,6 @@ def check_grad(org_model: Module,
                atol: float = 1e-5,
                rtol: float = 1e-3,
                verbose: bool = False):

     for suffix in layer_suffix:
         org_grad = getattr_(org_model, suffix).weight.grad
         shard_grad = getattr_(sharded_model, suffix).weight.grad
@@ -242,7 +245,6 @@ def check_grad(org_model: Module,
     # embedding may be resized when using tensor parallel
     if shard_grad.shape[0] > org_grad.shape[0]:
         shard_grad = shard_grad[:org_grad.shape[0], :]

     if verbose and dist.get_rank() == 0:
         print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
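As the comment in this hunk notes, tensor parallelism may pad the embedding's vocabulary dimension so it divides evenly across ranks, leaving extra gradient rows; the slice drops them before comparison. A self-contained sketch, where the 30522 -> 30528 padding is an illustrative assumption:

import torch

# Hypothetical shapes: BERT's vocab padded from 30522 to 30528 rows under
# tensor parallelism; the sharded gradient is sliced back to the original
# size so it can be compared element-wise with the unsharded gradient.
org_grad = torch.zeros(30522, 128)
shard_grad = torch.zeros(30528, 128)
if shard_grad.shape[0] > org_grad.shape[0]:
    shard_grad = shard_grad[:org_grad.shape[0], :]
assert shard_grad.shape == org_grad.shape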
tests/test_shardformer/test_model/test_shard_bert.py

@@ -1,65 +1,98 @@
 import pytest
 import torch
-from torch import distributed as dist

 import colossalai
-from colossalai.cluster import ProcessGroupMesh
 from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.auto_policy import get_autopolicy
-from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
-from colossalai.testing import (
-    assert_hf_output_close,
-    clear_cache_before_run,
-    parameterize,
-    rerun_if_address_is_in_use,
-    spawn,
-)
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
+from tests.test_shardformer.test_model._utils import (
+    build_model_from_hybrid_plugin,
+    check_grad,
+    check_loss,
+    check_output_hidden_state,
+    check_weight,
+    run_forward_backward_with_hybrid_plugin,
+)


-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
-    # unwarp model
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+
+    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
+        build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
+
+    org_loss, org_output, sharded_loss, sharded_output = \
+        run_forward_backward_with_hybrid_plugin(
+            org_model,
+            sharded_model,
+            sharded_optimizer,
+            data_gen_fn,
+            output_transform_fn,
+            criterion,
+            booster)
+
+    stage_manager = booster.plugin.stage_manager
+    tp_group = booster.plugin.tp_group
+
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if org_model.__class__.__name__ == 'BertModel':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+
+        check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
+
+    # unwrap model
     if org_model.__class__.__name__ == 'BertModel':
         bert = org_model
-        sharded_bert = sharded_model
+        sharded_bert = sharded_model.unwrap()
     else:
         bert = org_model.bert
-        sharded_bert = sharded_model.bert
+        sharded_bert = sharded_model.unwrap().bert

-    # check forward
-    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
-                                                                 output_transform_fn, loss_fn)
-    assert_hf_output_close(org_output, shard_output)
+    col_layer_for_check = ['encoder.layer[0].output.dense']
+    row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']

-    # do backward
-    org_loss.backward()
-    shard_loss.backward()
+    if stage_manager is None or stage_manager.is_first_stage():
+        #check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3)
+        #check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3)
+        check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
+        check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)

-    assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+    # check weights after optimizer.step()
+    org_optimizer.step()
+    sharded_optimizer.step()
+    if stage_manager is None or stage_manager.is_first_stage():
+        check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False)

-    # check grad
-    col_layer_for_check = ['encoder.layer[0].attention.self.query', 'embeddings.word_embeddings']
-    row_layer_for_check = ['encoder.layer[0].attention.output.dense']
-    check_grad(bert, sharded_bert, col_layer_for_check, atol=1e-7, rtol=1e-3, dim=0, verbose=False)
-    check_grad(bert, sharded_bert, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False)
+    torch.cuda.empty_cache()

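The new checks are gated by pipeline stage: the layers compared (word embeddings, encoder.layer[0]) live on the first stage, while the loss and last hidden state only exist on the last stage. A small sketch of that gating logic, assuming only the is_first_stage/is_last_stage API visible in this diff:

def should_check(stage_manager, layers_on_first_stage: bool) -> bool:
    # Without pipeline parallelism there is no stage manager: check everywhere.
    if stage_manager is None:
        return True
    # Otherwise only the stage that actually holds the tensors can check them.
    return stage_manager.is_first_stage() if layers_on_first_stage \
        else stage_manager.is_last_stage()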
-@parameterize('enable_fused_normalization', [True, False])
-@parameterize('enable_tensor_parallelism', [True, False])
-@parameterize('enable_flash_attention', [True, False])
-@parameterize('enable_jit_fused', [True, False])
-@parameterize('use_lazy_init', [False, True])
-def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused,
-                  use_lazy_init):
+@parameterize('test_config', [{
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'use_lazy_init': True
+}, {
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_fused_normalization': False,
+    'use_lazy_init': False
+}, {
+    'tp_size': 4,
+    'pp_size': 1,
+    'enable_fused_normalization': True,
+    'use_lazy_init': False
+}])
+def run_bert_test(test_config):
+
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
-                                               enable_flash_attention, enable_jit_fused, use_lazy_init)
-        check_state_dict(org_model, sharded_model, name=name)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    test_config['precision'] = 'float'
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    Randomizer.reset_index()
     torch.cuda.empty_cache()

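These configs also explain why the launcher in the next hunk moves from 2 to 4 processes: each run needs tp_size * pp_size ranks, and the largest layouts here are 2 x 2 and 4 x 1. A quick check of that arithmetic; the claim that leftover ranks become data-parallel replicas is an assumption about the plugin's defaults:

# World-size arithmetic for the parameterized configs above.
configs = [
    {'tp_size': 1, 'pp_size': 2},
    {'tp_size': 2, 'pp_size': 2},
    {'tp_size': 4, 'pp_size': 1},
]
for cfg in configs:
    ranks_needed = cfg['tp_size'] * cfg['pp_size']
    # 4 spawned processes cover both the 2-rank and 4-rank layouts; any
    # leftover ranks are presumably absorbed as data-parallel replicas.
    assert 4 % ranks_needed == 0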
@@ -73,7 +106,7 @@ def check_bert(rank, world_size, port):
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_bert():
-    spawn(check_bert, 2)
+    spawn(check_bert, 4)


 if __name__ == "__main__":
tests/test_shardformer/test_model/test_shard_bert_pipeline.py (deleted)

@@ -1,107 +0,0 @@
-import pytest
-import torch
-
-import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.auto_policy import get_autopolicy
-from colossalai.shardformer.shard import ShardConfig
-from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
-from colossalai.testing import (
-    assert_hf_output_close,
-    clear_cache_before_run,
-    parameterize,
-    rerun_if_address_is_in_use,
-    spawn,
-)
-from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
-
-
-def check_bert_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
-    stage_manager = stage_manager
-    policy = get_autopolicy(model)
-    policy.set_model(model)
-    model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
-    policy.set_shard_config(model_config)
-    layers = policy.get_held_layers()
-    if stage_manager.is_first_stage():
-        assert len(layers) == 1 + 1
-    else:
-        if name == "transformers_bert":
-            assert len(layers) == 1 + 1
-        elif name in [
-                "transformers_bert_for_sequence_classification", "transformers_bert_for_token_classification",
-                "transformers_bert_for_mcq"
-        ]:
-            assert len(layers) == 1 + 3
-        else:
-            assert len(layers) == 1 + 2
-
-
-def check_bert_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
-    if name == 'transformers_bert_for_mcq':
-        x = torch.randint(0, 1000, (2, 3, 3)).cuda()
-        attention_mask = torch.ones_like(x).cuda()
-        if stage_manager.stage == 0:
-            output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
-            assert output['hidden_states'].shape == (6, 3, 128)
-        else:
-            hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda()
-            output = sharded_model(input_ids=x,
-                                   hidden_states=hidden_states,
-                                   attention_mask=attention_mask,
-                                   stage_manager=stage_manager)
-            assert output[0].shape == (2, 3)
-    else:
-        x = torch.randint(0, 1000, (2, 3)).cuda()
-        # one batch, 2 single sentences, each sentence has 3 tokens
-        hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
-        if stage_manager.stage == 0:
-            attention_mask = torch.ones_like(x).cuda()
-            output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
-            assert output['hidden_states'].shape == (2, 3, 128)
-        else:
-            attention_mask = torch.ones((2, 3)).cuda()
-            output = sharded_model(hidden_states=hidden_states,
-                                   attention_mask=attention_mask,
-                                   stage_manager=stage_manager)
-            assert output[0].shape[0] == 2
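For reference, the shapes asserted in the multiple-choice branch of this deleted test follow from BERT flattening the choice dimension before its encoder: (batch, num_choices, seq_len) becomes (batch * num_choices, seq_len). A shape-only sketch:

import torch

# (batch=2, num_choices=3, seq_len=3) multiple-choice input_ids.
x = torch.randint(0, 1000, (2, 3, 3))
# BertForMultipleChoice-style flattening before the encoder.
flat = x.view(-1, x.size(-1))
assert flat.shape == (6, 3)    # hence the (6, 3, 128) hidden-state assert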
-
-
-@parameterize('enable_fused_normalization', [False])
-@parameterize('enable_tensor_parallelism', [False])
-@parameterize('use_lazy_init', [False])
-#TODO: merge this into test_shard_bert
-def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
-    PP_DIM = 0
-    PP_SIZE = 2
-    pg_mesh = ProcessGroupMesh(PP_SIZE)
-    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
-                                                        enable_tensor_parallelism, use_lazy_init)
-        check_bert_model_policy(name, org_model, stage_manager)
-        check_bert_model_pipeline_forward(name, sharded_model, stage_manager)
-
-    torch.cuda.empty_cache()
-
-
-def check_bert(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_bert_test()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_bert():
-    spawn(check_bert, 2)
-
-
-if __name__ == "__main__":
-    test_bert()