You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_pipeline/test_cuda_rpc_chimera.py

81 lines
2.4 KiB

import torch
from torch import nn
import torch.autograd as autograd
from colossalai.pipeline.rpc import ChimeraPipelineEngine
from colossalai.testing import assert_close
from rpc_test_utils import rpc_run, parse_args, RpcTestModel
# Module-level sizes shared by every model built in this test.
# feat_num is the input feature width (used for the random input below);
# h is forwarded to RpcTestModel (presumably its hidden width — confirm).
feat_num, h = 100, 100
def partition(pp_rank: int, chunk: int, stage_num: int):
    """Build the sub-model for pipeline rank ``pp_rank``.

    ``chunk`` is part of the call signature used by callers but is not used
    here. torch's global RNG is reseeded to 1024 before construction, so
    repeated calls with the same arguments start from the same random state.

    Returns:
        An ``RpcTestModel(pp_rank, stage_num, feat_num, h)`` instance.
    """
    torch.manual_seed(1024)
    return RpcTestModel(pp_rank, stage_num, feat_num, h)
def run_master(args):
    """Run one Chimera pipeline step over RPC and compute a local reference.

    Builds a ``ChimeraPipelineEngine`` from ``partition`` and runs a single
    ``forward_backward`` on a random input. It then collects the forward
    result and the per-stage parameter gradients from the engine. Finally it
    builds the same stages as a local ``nn.Sequential`` on ``args.device`` and
    computes the reference forward sum and parameter gradients.

    NOTE(review): the comparison between the two result lists is currently
    disabled (see the TODO at the end of this function). As written, this
    test therefore only checks that the pipeline and the reference model run
    without raising.

    Args:
        args: parsed CLI namespace; reads ``device``, ``world_size`` and
            ``num_microbatches``.
    """
    torch.manual_seed(100)

    device = args.device
    stage_num = args.world_size
    chunk = 1
    num_microbatches = args.num_microbatches
    use_checkpoint = False

    sample_num = 1024
    batch_size = 1024
    assert sample_num % batch_size == 0

    engine = ChimeraPipelineEngine(partition_fn=partition,
                                   stage_num=stage_num,
                                   num_microbatches=num_microbatches,
                                   device=device,
                                   checkpoint=use_checkpoint)
    engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)

    input_sample = torch.randn((sample_num, feat_num), device=device)

    forward_result = engine.forward_backward(input_sample)

    actual_stage_num = engine._get_actual_stage_num()

    # Pipelined (cuda rpc) results: the forward result first, then each
    # stage's parameter gradients in stage order.
    cuda_rpc_result = [sum(forward_result[0])]
    grad = engine.remote_grad()
    for stage_id in range(actual_stage_num):
        cuda_rpc_result.extend(grad[stage_id])

    # Reference results from a single-process model made of the same stages.
    test_model = nn.Sequential(
        *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device)
    # NOTE(review): the reference may need only part of the input if the
    # engine splits it across pipelines — unverified:
    # input_sample = input_sample[len(input_sample) // 2:]
    input_sample = input_sample.requires_grad_()
    out_val = test_model(input_sample).sum()
    autograd.backward(out_val)
    single_result = [out_val]
    single_result.extend(p.grad for p in test_model.parameters())

    # TODO: re-enable the comparison once the Chimera results are expected to
    # match the single-process reference:
    # assert len(cuda_rpc_result) == len(single_result)
    # for r_c, r_s in zip(cuda_rpc_result, single_result):
    #     assert_close(r_c, r_s, 0.001, 0.001)
if __name__ == "__main__":
    # Script entry point: parse command-line options and hand run_master to rpc_run.
    rpc_run(parse_args(), run_master)