[shardformer] fix opt test hanging (#4521)

* [shardformer] fix opt test hanging

* fix

* test

* test

* test

* fix test

* fix test

* remove print

* add fix
pull/4544/head
flybird11111 1 year ago committed by GitHub
parent e241b74f24
commit d367b88785
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -103,21 +103,21 @@ class OPTPolicy(Policy):
target_key=OPTDecoderLayer)
# use flash attention
# if self.shard_config.enable_flash_attention:
# self.append_or_create_method_replacement(description={
# 'forward': get_opt_flash_attention_forward(),
# },
# policy=policy,
# target_key=OPTAttention)
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(description={
'forward': get_opt_flash_attention_forward(),
},
policy=policy,
target_key=OPTAttention)
# use jit fused operator
# if self.shard_config.enable_jit_fused:
# self.append_or_create_method_replacement(description={
# 'forward': get_jit_fused_opt_decoder_layer_forward(),
# 'dropout_add': get_jit_fused_dropout_add_func(),
# },
# policy=policy,
# target_key=OPTDecoderLayer)
if self.shard_config.enable_jit_fused:
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_opt_decoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=OPTDecoderLayer)
return policy

@ -184,24 +184,33 @@ class T5BasePolicy(Policy):
# use flash attention
if self.shard_config.enable_flash_attention:
policy[T5Attention] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_t5_flash_attention_forward(),
})
},
policy=policy,
target_key=T5Attention)
# use jit operator
if self.shard_config.enable_jit_fused:
policy[T5LayerFF] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_T5_layer_ff_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[T5LayerSelfAttention] = ModulePolicyDescription(method_replacement={
},
policy=policy,
target_key=T5LayerFF)
self.append_or_create_method_replacement(description={
'forward': get_T5_layer_self_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[T5LayerCrossAttention] = ModulePolicyDescription(method_replacement={
},
policy=policy,
target_key=T5LayerSelfAttention)
self.append_or_create_method_replacement(description={
'forward': get_T5_layer_cross_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
},
policy=policy,
target_key=T5LayerCrossAttention)
return policy
def postprocess(self):

@ -56,9 +56,6 @@ class WhisperPolicy(Policy):
self.shard_config.enable_sequence_parallelism = False
warnings.warn(
"Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_jit_fused:
self.shard_config.enable_jit_fused = False
warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused flag.")
if self.shard_config.enable_tensor_parallelism:
policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
@ -212,6 +209,21 @@ class WhisperPolicy(Policy):
policy=policy,
target_key=WhisperAttention)
# use jit fused operator
if self.shard_config.enable_jit_fused:
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_whisper_decoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=WhisperDecoderLayer)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_whisper_encoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=WhisperEncoderLayer)
return policy
def add_lm_head_policy(self, base_policy):

@ -237,6 +237,43 @@ def check_weight(org_model: Module,
f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
def get_grad_tensors_for_check(org_model: Module,
sharded_model: Module,
layer_suffix: List[str],
tp_group: ProcessGroup = None,
dim: int = 0,
atol: float = 1e-5,
rtol: float = 1e-3,
verbose: bool = False,
name: str = None):
grad_to_check = {}
for suffix in layer_suffix:
org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
shard_weight = getattr_(sharded_model, suffix).weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
dist.all_gather(shard_grad_list, shard_grad, tp_group)
shard_grad = torch.cat(shard_grad_list, dim=dim)
# embedding may be resized when using tensor parallel
if shard_grad.shape[0] > org_grad.shape[0]:
shard_grad = shard_grad[:org_grad.shape[0], :]
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
grad_to_check[suffix] = {
"org_grad": org_grad.float(),
"shard_grad": shard_grad.float(),
"rtol": rtol,
"atol": atol
}
return grad_to_check
# used by sam/blip2
def check_grad(org_model: Module,
sharded_model: Module,
layer_suffix: List[str],
@ -275,3 +312,18 @@ def unwrap_model(module: Module,
if module.__class__.__name__ == base_model_class_name:
return module
return getattr(module, base_model_attribute_name, None)
def check_all_grad_tensors(check_tensors):
"""
"org_grad": tensor to be compared from the original model
"shard_grad": tensor to be compared from the sharded model
"""
for suffix, check_info in check_tensors.items():
org_grad = check_info["org_grad"]
shard_grad = check_info["shard_grad"]
rtol = check_info["rtol"]
atol = check_info["atol"]
assert torch.allclose(
org_grad, shard_grad, atol=atol, rtol=rtol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"

@ -10,10 +10,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@ -33,18 +34,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
output_transform_fn,
criterion,
booster)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'BertModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
bert = unwrap_model(org_model, 'BertModel', 'bert')
sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert')
@ -52,17 +44,48 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
col_layer_for_check = ['encoder.layer[0].output.dense']
row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
# check weights after optimizer.step()
col_layer_grads = get_grad_tensors_for_check(bert,
sharded_bert,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
row_layer_grads = get_grad_tensors_for_check(bert,
sharded_bert,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'BertModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if test_config['precision'] == 'fp32':
atol, rtol = 5e-3, 1e-3
else:
@ -70,6 +93,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
if stage_manager is None or stage_manager.is_first_stage():
check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()

@ -9,10 +9,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@ -36,35 +37,54 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'BloomModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
bloom = unwrap_model(org_model, 'BloomModel', 'transformer')
sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer')
# check grad
row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
col_layer_for_check = ['h[0].self_attention.dense']
# Save gradient tensors for comparison between the original model and the sharded model.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-6, 1e-5
else:
atol, rtol = 5e-3, 5e-3
check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
# check weights after optimizer.step()
row_layer_grads = get_grad_tensors_for_check(bloom,
sharded_bloom,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
col_layer_grads = get_grad_tensors_for_check(bloom,
sharded_bloom,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'BloomModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
@ -72,6 +92,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 5e-3
check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()

@ -9,10 +9,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@ -36,31 +37,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'ChatGLMModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer')
shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer')
# check grad
row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
col_layer_for_check = ['encoder.layers[0].self_attention.dense']
# Save gradient tensors for comparison between the original model and the sharded model.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-6, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_grad(chatglm_model,
row_layer_grads = get_grad_tensors_for_check(chatglm_model,
shard_chatglm_model,
row_layer_for_check,
tp_group,
@ -69,7 +60,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=0,
verbose=False)
check_grad(chatglm_model,
col_layer_grads = get_grad_tensors_for_check(chatglm_model,
shard_chatglm_model,
col_layer_for_check,
tp_group,
@ -77,10 +68,26 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
rtol=rtol,
dim=1,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# check weights after optimizer.step()
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'ChatGLMModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
@ -95,6 +102,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1,
verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()

@ -9,10 +9,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@ -36,18 +37,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'GPT2Model':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer')
sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer')
@ -55,18 +44,49 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
col_layer_for_check = ['h[0].mlp.c_fc']
row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
# check grad
# Save gradient tensors for comparison between the original model and the sharded model.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
# check weights after optimizer.step()
col_layer_grads = get_grad_tensors_for_check(gpt2,
sharded_gpt2,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
row_layer_grads = get_grad_tensors_for_check(gpt2,
sharded_gpt2,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'GPT2Model':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 5e-3, 1e-3
@ -74,6 +94,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 5e-3
check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()

@ -12,10 +12,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@ -41,30 +42,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'LlamaModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
llama_model = unwrap_model(org_model, 'LlamaModel', 'model')
shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model')
# check grad
row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
col_layer_for_check = ['layers[0].self_attn.o_proj']
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-6, 1e-4
else:
atol, rtol = 5e-3, 5e-3
check_grad(llama_model,
row_layer_grads = get_grad_tensors_for_check(llama_model,
shard_llama_model,
row_layer_for_check,
tp_group,
@ -72,7 +64,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
rtol=rtol,
dim=0,
verbose=False)
check_grad(llama_model,
col_layer_grads = get_grad_tensors_for_check(llama_model,
shard_llama_model,
col_layer_for_check,
tp_group,
@ -80,10 +72,26 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
rtol=rtol,
dim=1,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# check weights after optimizer.step()
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'LlamaModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
@ -98,6 +106,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1,
verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()

@ -11,10 +11,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@ -40,30 +41,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'OPTModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
opt_model = unwrap_model(org_model, 'OPTModel', 'model')
shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model')
# check grad
row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens'
col_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
# Save gradient tensors for comparison between the original model and the sharded model.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-6, 1e-3
else:
atol, rtol = 3e-2, 3e-2
check_grad(opt_model,
atol, rtol = 4e-2, 4e-2
row_layer_grads = get_grad_tensors_for_check(opt_model,
shard_opt_model,
row_layer_for_check,
tp_group,
@ -71,7 +63,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
rtol=rtol,
dim=0,
verbose=False)
check_grad(opt_model,
col_layer_grads = get_grad_tensors_for_check(opt_model,
shard_opt_model,
col_layer_for_check,
tp_group,
@ -79,10 +71,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
rtol=rtol,
dim=1,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# check weights after optimizer.step()
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'OPTModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-3, 1e-3
@ -97,6 +104,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1,
verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()

@ -10,10 +10,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@ -37,42 +38,55 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ != 'T5ForConditionalGeneration':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
t5 = unwrap_model(org_model)
sharded_t5 = unwrap_model(sharded_model)
row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
# check grad
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
# check weights after optimizer.step()
row_layer_grads = get_grad_tensors_for_check(t5,
sharded_t5,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ != 'T5ForConditionalGeneration':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if test_config['precision'] == 'fp32':
atol, rtol = 5e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()

@ -9,10 +9,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@ -36,17 +37,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'ViTModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
vit_model = unwrap_model(org_model, 'ViTModel', 'vit')
shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit')
@ -54,12 +44,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check grad
row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
col_layer_for_check = ['encoder.layer[0].attention.output.dense']
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_grad(vit_model,
row_layer_grads = get_grad_tensors_for_check(vit_model,
shard_vit_model,
row_layer_for_check,
tp_group,
@ -67,7 +60,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
rtol=rtol,
dim=0,
verbose=False)
check_grad(vit_model,
col_layer_grads = get_grad_tensors_for_check(vit_model,
shard_vit_model,
col_layer_for_check,
tp_group,
@ -75,10 +68,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
rtol=rtol,
dim=1,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# check weights after optimizer.step()
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'ViTModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 5e-3, 1e-3
@ -93,6 +101,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1,
verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()

@ -15,10 +15,11 @@ from colossalai.testing import (
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
)
@ -41,18 +42,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 2e-4, 2e-4
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'WhisperModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwarp the model
if org_model.__class__.__name__ == 'WhisperForConditionalGeneration':
whisper = org_model.model
@ -75,19 +64,48 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
#'decoder.layers[0].self_attn.out_proj'
]
# check weights and gradients
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if test_config['precision'] == 'fp32':
atol, rtol = 2e-4, 2e-4
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
row_layer_grads = get_grad_tensors_for_check(whisper,
sharded_whisper,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1)
col_layer_grads = get_grad_tensors_for_check(whisper,
sharded_whisper,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# check weights after optimizer.step()
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 2e-4, 2e-4
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'WhisperModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if test_config['precision'] == 'fp32':
atol, rtol = 5e-4, 5e-4
else:
@ -110,8 +128,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=0,
verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
#TODO fix WhisperForConditionalGeneration enable jit fused operato
# TODOjianghai) fix fp16
@parameterize(

Loading…
Cancel
Save