Browse Source

[hotfix] fix megatron_init in test_gpt2.py (#1357)

pull/1358/head
HELSON 2 years ago committed by GitHub
parent
commit
7a065dc9f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      tests/test_tensor/model/test_gpt2.py

2
tests/test_tensor/model/test_gpt2.py

@ -56,7 +56,7 @@ def init_megatron_spec(model, pg: ProcessGroup):
elif 'wte' in mn or 'wpe' in mn:
assert 'weight' in pn
split_param_col_tp1d(param, pg)
elif 'c_fc' in mn or 'c_proj' in mn:
elif 'c_attn' in mn or 'c_proj' in mn:
split_param_col_tp1d(param, pg)
# debug_print([0], '\t', param.compute_spec, param.shape)

Loading…
Cancel
Save