mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] fix opt test hanging (#4521)
* [shardformer] fix opt test hanging * fix * test * test * test * fix test * fix test * remove print * add fixpull/4544/head
parent
e241b74f24
commit
d367b88785
|
@ -103,21 +103,21 @@ class OPTPolicy(Policy):
|
|||
target_key=OPTDecoderLayer)
|
||||
|
||||
# use flash attention
|
||||
# if self.shard_config.enable_flash_attention:
|
||||
# self.append_or_create_method_replacement(description={
|
||||
# 'forward': get_opt_flash_attention_forward(),
|
||||
# },
|
||||
# policy=policy,
|
||||
# target_key=OPTAttention)
|
||||
if self.shard_config.enable_flash_attention:
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_opt_flash_attention_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=OPTAttention)
|
||||
|
||||
# use jit fused operator
|
||||
# if self.shard_config.enable_jit_fused:
|
||||
# self.append_or_create_method_replacement(description={
|
||||
# 'forward': get_jit_fused_opt_decoder_layer_forward(),
|
||||
# 'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
# },
|
||||
# policy=policy,
|
||||
# target_key=OPTDecoderLayer)
|
||||
if self.shard_config.enable_jit_fused:
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_jit_fused_opt_decoder_layer_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=OPTDecoderLayer)
|
||||
|
||||
return policy
|
||||
|
||||
|
|
|
@ -184,24 +184,33 @@ class T5BasePolicy(Policy):
|
|||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
policy[T5Attention] = ModulePolicyDescription(method_replacement={
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_t5_flash_attention_forward(),
|
||||
})
|
||||
},
|
||||
policy=policy,
|
||||
target_key=T5Attention)
|
||||
|
||||
# use jit operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
policy[T5LayerFF] = ModulePolicyDescription(method_replacement={
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_jit_fused_T5_layer_ff_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
policy[T5LayerSelfAttention] = ModulePolicyDescription(method_replacement={
|
||||
},
|
||||
policy=policy,
|
||||
target_key=T5LayerFF)
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_T5_layer_self_attention_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
policy[T5LayerCrossAttention] = ModulePolicyDescription(method_replacement={
|
||||
},
|
||||
policy=policy,
|
||||
target_key=T5LayerSelfAttention)
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_T5_layer_cross_attention_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
})
|
||||
},
|
||||
policy=policy,
|
||||
target_key=T5LayerCrossAttention)
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
|
|
|
@ -56,9 +56,6 @@ class WhisperPolicy(Policy):
|
|||
self.shard_config.enable_sequence_parallelism = False
|
||||
warnings.warn(
|
||||
"Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
||||
if self.shard_config.enable_jit_fused:
|
||||
self.shard_config.enable_jit_fused = False
|
||||
warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused flag.")
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
|
||||
|
@ -212,6 +209,21 @@ class WhisperPolicy(Policy):
|
|||
policy=policy,
|
||||
target_key=WhisperAttention)
|
||||
|
||||
# use jit fused operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_jit_fused_whisper_decoder_layer_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=WhisperDecoderLayer)
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_jit_fused_whisper_encoder_layer_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=WhisperEncoderLayer)
|
||||
|
||||
return policy
|
||||
|
||||
def add_lm_head_policy(self, base_policy):
|
||||
|
|
|
@ -237,6 +237,43 @@ def check_weight(org_model: Module,
|
|||
f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
|
||||
|
||||
|
||||
def get_grad_tensors_for_check(org_model: Module,
|
||||
sharded_model: Module,
|
||||
layer_suffix: List[str],
|
||||
tp_group: ProcessGroup = None,
|
||||
dim: int = 0,
|
||||
atol: float = 1e-5,
|
||||
rtol: float = 1e-3,
|
||||
verbose: bool = False,
|
||||
name: str = None):
|
||||
|
||||
grad_to_check = {}
|
||||
for suffix in layer_suffix:
|
||||
org_grad = getattr_(org_model, suffix).weight.grad
|
||||
shard_grad = getattr_(sharded_model, suffix).weight.grad
|
||||
shard_weight = getattr_(sharded_model, suffix).weight
|
||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
||||
shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
|
||||
dist.all_gather(shard_grad_list, shard_grad, tp_group)
|
||||
shard_grad = torch.cat(shard_grad_list, dim=dim)
|
||||
|
||||
# embedding may be resized when using tensor parallel
|
||||
if shard_grad.shape[0] > org_grad.shape[0]:
|
||||
shard_grad = shard_grad[:org_grad.shape[0], :]
|
||||
if verbose and dist.get_rank() == 0:
|
||||
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
|
||||
|
||||
grad_to_check[suffix] = {
|
||||
"org_grad": org_grad.float(),
|
||||
"shard_grad": shard_grad.float(),
|
||||
"rtol": rtol,
|
||||
"atol": atol
|
||||
}
|
||||
|
||||
return grad_to_check
|
||||
|
||||
|
||||
# used by sam/blip2
|
||||
def check_grad(org_model: Module,
|
||||
sharded_model: Module,
|
||||
layer_suffix: List[str],
|
||||
|
@ -275,3 +312,18 @@ def unwrap_model(module: Module,
|
|||
if module.__class__.__name__ == base_model_class_name:
|
||||
return module
|
||||
return getattr(module, base_model_attribute_name, None)
|
||||
|
||||
|
||||
def check_all_grad_tensors(check_tensors):
|
||||
"""
|
||||
"org_grad": tensor to be compared from the original model
|
||||
"shard_grad": tensor to be compared from the sharded model
|
||||
"""
|
||||
for suffix, check_info in check_tensors.items():
|
||||
org_grad = check_info["org_grad"]
|
||||
shard_grad = check_info["shard_grad"]
|
||||
rtol = check_info["rtol"]
|
||||
atol = check_info["atol"]
|
||||
assert torch.allclose(
|
||||
org_grad, shard_grad, atol=atol, rtol=rtol
|
||||
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
|
||||
|
|
|
@ -10,10 +10,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
|
|||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
check_grad,
|
||||
check_all_grad_tensors,
|
||||
check_loss,
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
get_grad_tensors_for_check,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
@ -33,8 +34,46 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
output_transform_fn,
|
||||
criterion,
|
||||
booster)
|
||||
|
||||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
||||
bert = unwrap_model(org_model, 'BertModel', 'bert')
|
||||
sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert')
|
||||
|
||||
col_layer_for_check = ['encoder.layer[0].output.dense']
|
||||
row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']
|
||||
|
||||
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
|
||||
grads_to_check = {}
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-4, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
col_layer_grads = get_grad_tensors_for_check(bert,
|
||||
sharded_bert,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
row_layer_grads = get_grad_tensors_for_check(bert,
|
||||
sharded_bert,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=0,
|
||||
verbose=False)
|
||||
grads_to_check.update(col_layer_grads)
|
||||
grads_to_check.update(row_layer_grads)
|
||||
|
||||
# optimizer executes step
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
|
@ -46,23 +85,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
bert = unwrap_model(org_model, 'BertModel', 'bert')
|
||||
sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert')
|
||||
|
||||
col_layer_for_check = ['encoder.layer[0].output.dense']
|
||||
row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']
|
||||
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-4, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
|
||||
check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
|
||||
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
# check weights
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 5e-3, 1e-3
|
||||
else:
|
||||
|
@ -70,6 +93,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
|
||||
|
||||
# check grads
|
||||
check_all_grad_tensors(grads_to_check)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
|
|
@ -9,10 +9,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
|
|||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
check_grad,
|
||||
check_all_grad_tensors,
|
||||
check_loss,
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
get_grad_tensors_for_check,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
@ -36,6 +37,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
||||
# unwrap model
|
||||
bloom = unwrap_model(org_model, 'BloomModel', 'transformer')
|
||||
sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer')
|
||||
|
||||
row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
|
||||
col_layer_for_check = ['h[0].self_attention.dense']
|
||||
|
||||
# Save gradient tensors for comparison between the original model and the sharded model.
|
||||
grads_to_check = {}
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-6, 1e-5
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
row_layer_grads = get_grad_tensors_for_check(bloom,
|
||||
sharded_bloom,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=0,
|
||||
verbose=False)
|
||||
col_layer_grads = get_grad_tensors_for_check(bloom,
|
||||
sharded_bloom,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
grads_to_check.update(col_layer_grads)
|
||||
grads_to_check.update(row_layer_grads)
|
||||
|
||||
# optimizer executes step
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
|
@ -47,24 +85,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# unwrap model
|
||||
bloom = unwrap_model(org_model, 'BloomModel', 'transformer')
|
||||
sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer')
|
||||
|
||||
# check grad
|
||||
row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
|
||||
col_layer_for_check = ['h[0].self_attention.dense']
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-6, 1e-5
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
|
||||
check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
|
||||
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-4, 1e-3
|
||||
|
@ -72,6 +92,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
atol, rtol = 5e-3, 5e-3
|
||||
check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
|
||||
|
||||
# check grads
|
||||
check_all_grad_tensors(grads_to_check)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
|
|
@ -9,10 +9,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
|
|||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
check_grad,
|
||||
check_all_grad_tensors,
|
||||
check_loss,
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
get_grad_tensors_for_check,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
@ -36,6 +37,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
||||
# unwrap model
|
||||
chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer')
|
||||
shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer')
|
||||
|
||||
row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
|
||||
col_layer_for_check = ['encoder.layers[0].self_attention.dense']
|
||||
|
||||
# Save gradient tensors for comparison between the original model and the sharded model.
|
||||
grads_to_check = {}
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-6, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
row_layer_grads = get_grad_tensors_for_check(chatglm_model,
|
||||
shard_chatglm_model,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=0,
|
||||
verbose=False)
|
||||
|
||||
col_layer_grads = get_grad_tensors_for_check(chatglm_model,
|
||||
shard_chatglm_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
grads_to_check.update(col_layer_grads)
|
||||
grads_to_check.update(row_layer_grads)
|
||||
|
||||
# optimizer executes step
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
|
@ -48,39 +87,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# unwrap model
|
||||
chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer')
|
||||
shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer')
|
||||
|
||||
# check grad
|
||||
row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
|
||||
col_layer_for_check = ['encoder.layers[0].self_attention.dense']
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-6, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
check_grad(chatglm_model,
|
||||
shard_chatglm_model,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=0,
|
||||
verbose=False)
|
||||
|
||||
check_grad(chatglm_model,
|
||||
shard_chatglm_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
# check weights
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-4, 1e-3
|
||||
|
@ -95,6 +102,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
dim=1,
|
||||
verbose=False)
|
||||
|
||||
# check grads
|
||||
check_all_grad_tensors(grads_to_check)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
|
|
@ -9,10 +9,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
|
|||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
check_grad,
|
||||
check_all_grad_tensors,
|
||||
check_loss,
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
get_grad_tensors_for_check,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
@ -36,6 +37,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
||||
# unwrap model
|
||||
gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer')
|
||||
sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer')
|
||||
|
||||
col_layer_for_check = ['h[0].mlp.c_fc']
|
||||
row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
|
||||
|
||||
# Save gradient tensors for comparison between the original model and the sharded model.
|
||||
grads_to_check = {}
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-4, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
col_layer_grads = get_grad_tensors_for_check(gpt2,
|
||||
sharded_gpt2,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
row_layer_grads = get_grad_tensors_for_check(gpt2,
|
||||
sharded_gpt2,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=0,
|
||||
verbose=False)
|
||||
grads_to_check.update(col_layer_grads)
|
||||
grads_to_check.update(row_layer_grads)
|
||||
|
||||
# optimizer executes step
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
|
@ -48,25 +86,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# unwrap model
|
||||
gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer')
|
||||
sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer')
|
||||
|
||||
col_layer_for_check = ['h[0].mlp.c_fc']
|
||||
row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
|
||||
|
||||
# check grad
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-4, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
|
||||
check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
|
||||
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
# check weights
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 5e-3, 1e-3
|
||||
|
@ -74,6 +94,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
atol, rtol = 5e-3, 5e-3
|
||||
check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
|
||||
|
||||
# check grads
|
||||
check_all_grad_tensors(grads_to_check)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
|
|
@ -12,10 +12,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
|
|||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
check_grad,
|
||||
check_all_grad_tensors,
|
||||
check_loss,
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
get_grad_tensors_for_check,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
@ -41,6 +42,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
||||
# unwrap model
|
||||
llama_model = unwrap_model(org_model, 'LlamaModel', 'model')
|
||||
shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model')
|
||||
|
||||
row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
|
||||
col_layer_for_check = ['layers[0].self_attn.o_proj']
|
||||
|
||||
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
|
||||
grads_to_check = {}
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-6, 1e-4
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
row_layer_grads = get_grad_tensors_for_check(llama_model,
|
||||
shard_llama_model,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=0,
|
||||
verbose=False)
|
||||
col_layer_grads = get_grad_tensors_for_check(llama_model,
|
||||
shard_llama_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
grads_to_check.update(col_layer_grads)
|
||||
grads_to_check.update(row_layer_grads)
|
||||
|
||||
# optimizer executes step
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
|
@ -53,37 +91,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# unwrap model
|
||||
llama_model = unwrap_model(org_model, 'LlamaModel', 'model')
|
||||
shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model')
|
||||
# check grad
|
||||
row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
|
||||
col_layer_for_check = ['layers[0].self_attn.o_proj']
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-6, 1e-4
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
check_grad(llama_model,
|
||||
shard_llama_model,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=0,
|
||||
verbose=False)
|
||||
check_grad(llama_model,
|
||||
shard_llama_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
# check weights
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-4, 1e-3
|
||||
|
@ -98,6 +106,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
dim=1,
|
||||
verbose=False)
|
||||
|
||||
# check grads
|
||||
check_all_grad_tensors(grads_to_check)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
|
|
@ -11,10 +11,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
|
|||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
check_grad,
|
||||
check_all_grad_tensors,
|
||||
check_loss,
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
get_grad_tensors_for_check,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
@ -40,6 +41,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
||||
# unwrap model
|
||||
opt_model = unwrap_model(org_model, 'OPTModel', 'model')
|
||||
shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model')
|
||||
|
||||
row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens'
|
||||
col_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
|
||||
|
||||
# Save gradient tensors for comparison between the original model and the sharded model.
|
||||
grads_to_check = {}
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-6, 1e-3
|
||||
else:
|
||||
atol, rtol = 4e-2, 4e-2
|
||||
row_layer_grads = get_grad_tensors_for_check(opt_model,
|
||||
shard_opt_model,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=0,
|
||||
verbose=False)
|
||||
col_layer_grads = get_grad_tensors_for_check(opt_model,
|
||||
shard_opt_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
grads_to_check.update(col_layer_grads)
|
||||
grads_to_check.update(row_layer_grads)
|
||||
|
||||
# optimizer executes step
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
|
@ -51,38 +89,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# unwrap model
|
||||
opt_model = unwrap_model(org_model, 'OPTModel', 'model')
|
||||
shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model')
|
||||
|
||||
# check grad
|
||||
row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens'
|
||||
col_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-6, 1e-3
|
||||
else:
|
||||
atol, rtol = 3e-2, 3e-2
|
||||
check_grad(opt_model,
|
||||
shard_opt_model,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=0,
|
||||
verbose=False)
|
||||
check_grad(opt_model,
|
||||
shard_opt_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
# check weights
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-3, 1e-3
|
||||
|
@ -97,6 +104,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
dim=1,
|
||||
verbose=False)
|
||||
|
||||
# check grads
|
||||
check_all_grad_tensors(grads_to_check)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
|
|
@ -10,10 +10,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
|
|||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
check_grad,
|
||||
check_all_grad_tensors,
|
||||
check_loss,
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
get_grad_tensors_for_check,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
@ -37,6 +38,32 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
||||
# unwrap model
|
||||
t5 = unwrap_model(org_model)
|
||||
sharded_t5 = unwrap_model(sharded_model)
|
||||
|
||||
row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
|
||||
|
||||
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
|
||||
grads_to_check = {}
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-5, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
row_layer_grads = get_grad_tensors_for_check(t5,
|
||||
sharded_t5,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=0)
|
||||
grads_to_check.update(row_layer_grads)
|
||||
|
||||
# optimizer executes step
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
|
@ -49,30 +76,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# unwrap model
|
||||
t5 = unwrap_model(org_model)
|
||||
sharded_t5 = unwrap_model(sharded_model)
|
||||
|
||||
row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
|
||||
|
||||
# check grad
|
||||
# check weights
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-5, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
|
||||
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-4, 1e-3
|
||||
atol, rtol = 5e-4, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
|
||||
|
||||
# check grads
|
||||
check_all_grad_tensors(grads_to_check)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
|
|
@ -9,10 +9,11 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
|
|||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
check_grad,
|
||||
check_all_grad_tensors,
|
||||
check_loss,
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
get_grad_tensors_for_check,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
unwrap_model,
|
||||
)
|
||||
|
@ -36,6 +37,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
||||
# unwrap model
|
||||
vit_model = unwrap_model(org_model, 'ViTModel', 'vit')
|
||||
shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit')
|
||||
|
||||
# check grad
|
||||
row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
|
||||
col_layer_for_check = ['encoder.layer[0].attention.output.dense']
|
||||
|
||||
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
|
||||
grads_to_check = {}
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-5, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
row_layer_grads = get_grad_tensors_for_check(vit_model,
|
||||
shard_vit_model,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=0,
|
||||
verbose=False)
|
||||
col_layer_grads = get_grad_tensors_for_check(vit_model,
|
||||
shard_vit_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
grads_to_check.update(col_layer_grads)
|
||||
grads_to_check.update(row_layer_grads)
|
||||
|
||||
# optimizer executes step
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
|
@ -47,38 +86,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# unwrap model
|
||||
vit_model = unwrap_model(org_model, 'ViTModel', 'vit')
|
||||
shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit')
|
||||
|
||||
# check grad
|
||||
row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
|
||||
col_layer_for_check = ['encoder.layer[0].attention.output.dense']
|
||||
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-5, 1e-3
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
check_grad(vit_model,
|
||||
shard_vit_model,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=0,
|
||||
verbose=False)
|
||||
check_grad(vit_model,
|
||||
shard_vit_model,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
# check weights
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 5e-3, 1e-3
|
||||
|
@ -93,6 +101,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
dim=1,
|
||||
verbose=False)
|
||||
|
||||
# check grads
|
||||
check_all_grad_tensors(grads_to_check)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
|
|
@ -15,10 +15,11 @@ from colossalai.testing import (
|
|||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
check_grad,
|
||||
check_all_grad_tensors,
|
||||
check_loss,
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
get_grad_tensors_for_check,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
)
|
||||
|
||||
|
@ -41,18 +42,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 2e-4, 2e-4
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
if org_model.__class__.__name__ == 'WhisperModel':
|
||||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
|
||||
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# unwarp the model
|
||||
if org_model.__class__.__name__ == 'WhisperForConditionalGeneration':
|
||||
whisper = org_model.model
|
||||
|
@ -75,19 +64,48 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
#'decoder.layers[0].self_attn.out_proj'
|
||||
]
|
||||
|
||||
# check weights and gradients
|
||||
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
|
||||
grads_to_check = {}
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 2e-4, 2e-4
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
|
||||
check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
|
||||
row_layer_grads = get_grad_tensors_for_check(whisper,
|
||||
sharded_whisper,
|
||||
row_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=1)
|
||||
col_layer_grads = get_grad_tensors_for_check(whisper,
|
||||
sharded_whisper,
|
||||
col_layer_for_check,
|
||||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=0)
|
||||
grads_to_check.update(col_layer_grads)
|
||||
grads_to_check.update(row_layer_grads)
|
||||
|
||||
# check weights after optimizer.step()
|
||||
# optimizer executes step
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 2e-4, 2e-4
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
if org_model.__class__.__name__ == 'WhisperModel':
|
||||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
|
||||
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
# check weights
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 5e-4, 5e-4
|
||||
else:
|
||||
|
@ -110,8 +128,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
dim=0,
|
||||
verbose=False)
|
||||
|
||||
# check grads
|
||||
check_all_grad_tensors(grads_to_check)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
#TODO fix WhisperForConditionalGeneration enable jit fused operato
|
||||
# TODO(jianghai) fix fp16
|
||||
@parameterize(
|
||||
|
|
Loading…
Reference in New Issue