diff --git a/tests/test_tensor/model/test_gpt2.py b/tests/test_tensor/model/test_gpt2.py index b74016451..6f2ef9fa8 100644 --- a/tests/test_tensor/model/test_gpt2.py +++ b/tests/test_tensor/model/test_gpt2.py @@ -56,7 +56,7 @@ def init_megatron_spec(model, pg: ProcessGroup): elif 'wte' in mn or 'wpe' in mn: assert 'weight' in pn split_param_col_tp1d(param, pg) - elif 'c_fc' in mn or 'c_proj' in mn: + elif 'c_attn' in mn or 'c_proj' in mn: split_param_col_tp1d(param, pg) # debug_print([0], '\t', param.compute_spec, param.shape)