ColossalAI/tests/test_zero_data_parallel/test_shard_model_v2.py

58 lines
1.6 KiB
Python
Raw Normal View History

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import copy
from functools import partial
from operator import mod
from pyexpat import model
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from tests.test_zero_data_parallel.common import Net, CONFIG, check_grads
def run_fwd_bwd(model, x, enable_autocast=False):
    """Run one training-mode forward and backward pass of ``model`` on ``x``.

    The model output is reduced with ``sum()`` to obtain a scalar loss, which
    is then back-propagated.

    Args:
        model: Object exposing ``train()`` and callable as ``model(x)``.
        x: Input handed to the model unchanged.
        enable_autocast: Forwarded as ``enabled`` to
            ``torch.cuda.amp.autocast`` around the forward computation.

    Returns:
        None. Gradients are produced as a side effect of ``loss.backward()``.
    """
    model.train()
    # Only the forward computation and the loss reduction run under autocast.
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(x)
        loss = y.sum()
    # Cast the scalar loss to float32 before starting the backward pass.
    loss = loss.float()
    loss.backward()
def run_dist(rank, world_size, port):
    """Per-process body of the distributed ShardedModelV2 gradient check.

    Launches ColossalAI for this process, builds a reference ``Net`` and a
    deep copy of it wrapped in ``ShardedModelV2``, then runs both through the
    same forward/backward passes and calls ``check_grads`` on the pair.

    Args:
        rank: Rank of this process, forwarded to ``colossalai.launch``.
        world_size: Total number of processes, forwarded to
            ``colossalai.launch``.
        port: Port used together with host ``'localhost'`` by
            ``colossalai.launch``.
    """
    colossalai.launch(
        config=CONFIG,
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )

    model = Net(checkpoint=True).cuda()
    # Deep-copy before wrapping so both models start from identical weights
    # while the reference `model` stays unwrapped.
    zero_model = copy.deepcopy(model)
    zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))

    for _ in range(2):
        x = torch.rand(2, 5).cuda()
        # Feed the very same batch to the sharded and the reference model.
        run_fwd_bwd(zero_model, x, False)
        run_fwd_bwd(model, x, False)
        # NOTE(review): check_grads is defined in tests/.../common.py; it
        # presumably asserts that the two models' gradients agree - confirm.
        check_grads(model, zero_model)
@pytest.mark.dist
def test_shard_model_v2():
    """Spawn two worker processes that each execute ``run_dist``.

    ``world_size`` and a port obtained from ``free_port()`` are bound up
    front with ``partial``; ``mp.spawn`` supplies the process index as the
    first positional argument, which ``run_dist`` receives as ``rank``.
    """
    world_size = 2
    # free_port() is evaluated once here, so every worker gets the same port.
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
# Allow running this test file directly as a script, without pytest.
if __name__ == '__main__':
    test_shard_model_v2()