mirror of https://github.com/hpcaitech/ColossalAI
[unit test] add megatron init test in zero_optim (#1358)
parent 7a065dc9f6
commit 4417804129
@@ -18,6 +18,7 @@ from colossalai.testing import parameterize
 from colossalai.amp import convert_to_apex_amp
 from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor
+from tests.test_tensor.model.test_gpt2 import init_megatron_spec


 def check_param_equal(model, torch_model, pg: ProcessGroup):
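Note: the init_*_spec helpers imported here shard a model's parameters for tensor parallelism before the ZeRO optimizer test wraps the model. A minimal sketch of that pattern, assuming the conventions of this test suite at the time (the helper name below and the 'weight'/'ln' filter are illustrative; this is not the exact body of init_megatron_spec):

    # Hypothetical sketch of a 1D tensor-parallel init-spec helper; the real
    # init_megatron_spec lives in tests/test_tensor/model/test_gpt2.py.
    def init_1d_row_spec_sketch(model, pg: ProcessGroup):
        spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
        for n, p in model.named_parameters():
            # Shard only linear weights; LayerNorm parameters stay replicated.
            if 'weight' in n and 'ln' not in n:
                p.set_process_group(pg)
                p.set_tensor_spec(*spec)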
@@ -127,10 +128,10 @@ def run_dist(rank, world_size, port):
     config = {}
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     if world_size == 4:
         run_gpt(tp_init_spec_func=init_1d_col_spec)
-        run_gpt(tp_init_spec_func=init_1d_row_spec)
+        run_gpt(tp_init_spec_func=init_megatron_spec)
     else:
         run_gpt(tp_init_spec_func=init_1d_col_spec)
         run_gpt(tp_init_spec_func=init_1d_row_spec)


 @pytest.mark.dist
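Note: run_dist above is driven by a pytest entry point that spawns one process per rank, so the 4-GPU branch is what exercises the new init_megatron_spec path. A sketch assuming the usual pattern in this test suite (free_port is the helper ColossalAI tests used at the time; the parametrize values and test name here are assumptions):

    from functools import partial

    import pytest
    import torch.multiprocessing as mp

    from colossalai.utils import free_port

    @pytest.mark.dist
    @pytest.mark.parametrize('world_size', [1, 4])
    def test_gpt(world_size):
        # Spawn world_size workers; mp.spawn passes the rank as the first
        # argument, so each worker calls run_dist(rank, world_size, port).
        run_func = partial(run_dist, world_size=world_size, port=free_port())
        mp.spawn(run_func, nprocs=world_size)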