mirror of https://github.com/hpcaitech/ColossalAI
[pipeline/rpc] implement distributed optimizer | test with assert_close (#1486)

* support p2p communication with any type of object | pass test
* reconstruct pipeline schedule with p2p_v2.py (support communication with List[Any]) | pass test
* [engin/schedule] use p2p_v2 to reconstruct pipeline_schedule
* [pipeline/rpc] implement a demo for PP with cuda rpc framework
* [pipeline/rpc] support interleaving | fix checkpoint bug | change logic when dispatching data in work_list to ensure steady 1F1B
* [pipeline/rpc] implement distributed optimizer | test with assert_close
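In outline, the distributed optimizer API added here is driven from the master process as follows (a minimal sketch distilled from the tests below; the engine constructor arguments are abbreviated):

    engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
                                    stage_num=world_size,
                                    num_microbatches=num_microbatches,
                                    device='cuda')
    engine.initialize_optimizer(torch.optim.SGD, lr=1e-3)    # builds one optimizer per stage worker over RPC
    _ = engine.forward_backward(input_sample)                # pipelined 1F1B forward and backward passes
    engine.step()                                            # asynchronous step() on every stage, then wait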
parent 3da68d6b1b
commit 9145aef2b4
@@ -9,6 +9,7 @@ import torch.distributed.rpc as rpc
 from torch.futures import Future
 from torch._C._distributed_rpc import PyRRef
 from torch import autograd
+from torch import optim
 from tqdm import tqdm
 
 from colorama import Back, Style
@@ -43,8 +44,7 @@ def tensor_shape_list(tensors):
 class Phase(Enum):
     FORWARD = 0
     BACKWARD = 1
-    ACCUM_GRAD = 2
-    SYNC = 3
+    UPDATE = 2
 
 
 class UniqueKey:
@@ -440,8 +440,6 @@ class Worker:
             if isinstance(input_node, torch.Tensor):
                 consume_result.append(input_node.grad)
 
-        elif phase == Phase.SYNC:
-            pass
         else:
             raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
 
@@ -478,6 +476,18 @@ class Worker:
                         'work loop', 'green')
         work_item.output.set_result(consume_result)
 
+    def initialize_optimizer(self, optimizer_class: type, **kwargs):
+        self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
+
+    def step(self):
+        assert hasattr(self, "optimizer"), "call initialize_optimizer first before you call step!"
+        self.work_list.clear()
+        self.output_list.clear()
+        self.microbatch_id_to_backward_cache.clear()
+
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+
 
 # TODO
 # 1. chunk
@@ -617,6 +627,24 @@ class PipelineEngineBase(ABC, nn.Module):
             first_stage_worker.rpc_sync().get_output_by_key(key)
         return forward_result
 
+    def initialize_optimizer(self, optimizer_class: type, **kwargs):
+        actual_stage_num = self._get_actual_stage_num()
+        for pp_rank in range(actual_stage_num):
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+            worker_rref.rpc_sync().initialize_optimizer(optimizer_class, **kwargs)
+
+    def step(self):
+        step_futs: List[Future] = []
+        actual_stage_num = self._get_actual_stage_num()
+        for pp_rank in range(actual_stage_num):
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+            fut = worker_rref.rpc_async().step()
+            step_futs.append(fut)
+
+        # wait for all optimizers
+        for fut in step_futs:
+            fut.wait()
+
 
 class FillDrainPipelineEngine(PipelineEngineBase):
 
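Note on the step() method above: it issues one non-blocking rpc_async call per pipeline stage and only then joins on the returned futures, so all stage optimizers run their updates concurrently. The same fan-out/join pattern in isolation (illustrative names, not part of the patch):

    import torch.distributed.rpc as rpc

    def step_all_stages(worker_rrefs):
        # launch step() on every remote worker without blocking
        futs = [rref.rpc_async().step() for rref in worker_rrefs]
        # join: block until every stage finishes its optimizer update
        for fut in futs:
            fut.wait()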
@@ -0,0 +1,84 @@
+import os
+import argparse
+
+import torch
+from torch import nn
+import torch.multiprocessing as mp
+import torch.distributed.rpc as rpc
+from torch import autograd
+from torch.optim import SGD, Adam, RMSprop, Optimizer
+from colorama import Back, Style
+
+from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
+from colossalai.testing import assert_close
+
+
+def color_debug(text, prefix=' ', color='blue'):
+    color = color.upper()
+    print(getattr(Back, color), prefix, Style.RESET_ALL, text)
+
+
+class RpcTestModel(nn.Module):
+
+    def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None:
+        super().__init__()
+        self.rank = stage_id
+        self.is_last_rank = stage_id == actual_stage_num - 1
+        self.linear_name = f'linear_{stage_id}'
+        if stage_id == 0:
+            setattr(self, self.linear_name, nn.Linear(feat_num, h))
+        elif stage_id == actual_stage_num - 1:
+            setattr(self, self.linear_name, nn.Linear(h, 1))
+        else:
+            setattr(self, self.linear_name, nn.Linear(h, h))
+
+    def forward(self, x) -> torch.Tensor:
+        linear: nn.Module = getattr(self, self.linear_name)
+        out: torch.Tensor = linear(x)
+
+        if self.is_last_rank:
+            out = out.sum()
+        return out
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--world_size', type=int, default=2)
+    parser.add_argument('--num_microbatches', type=int, default=2)
+    parser.add_argument('--chunk', type=int, default=1)
+    parser.add_argument('--use_checkpoint', action='store_true')
+    parser.add_argument('--use_interleave', action='store_true')
+    parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD')
+    parser.add_argument('--device', type=str, default='cuda')
+    parser.add_argument('--master_addr', type=str, default='localhost')
+    parser.add_argument('--master_port', type=str, default='29020')
+    parser.add_argument('--num_worker_threads', type=int, default=128)
+    return parser.parse_args()
+
+
+def run_worker(rank, args, master_func):
+    os.environ['MASTER_ADDR'] = args.master_addr
+    os.environ['MASTER_PORT'] = args.master_port
+
+    # config rpc
+    # if cuda is used, set_device_map must be configured,
+    # because cuda is not supported in torch rpc by default
+    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=args.num_worker_threads)
+
+    world_size = args.world_size
+    for rank_idx in range(world_size):
+        options.set_device_map(f'work{rank_idx}', {rank: rank_idx})
+
+    rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options)
+
+    # in rpc mode, only rank 0 needs to run the master function
+    if rank == 0:
+        master_func(args)
+    # barrier here
+    rpc.shutdown()
+
+
+def rpc_run(args, master_func):
+    world_size = args.world_size
+    assert args.num_microbatches >= args.world_size, "num_microbatches cannot be fewer than world_size!"
+    mp.spawn(run_worker, args=(args, master_func), nprocs=world_size)
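Note on run_worker above: TensorPipe cannot carry CUDA tensors between processes unless device maps are declared, which is why set_device_map is called for every peer. The map {rank: rank_idx} means a tensor on this process's device `rank` arrives on device `rank_idx` of the peer `work{rank_idx}`. For example, on rank 0 with two GPUs (illustrative values):

    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=128)
    # tensors sent from cuda:0 here are placed on cuda:1 at 'work1'
    options.set_device_map('work1', {0: 1})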
@@ -0,0 +1,85 @@
+import os
+import argparse
+
+import torch
+from torch import nn
+import torch.multiprocessing as mp
+import torch.distributed.rpc as rpc
+from torch import autograd
+from torch.optim import SGD, Adam, RMSprop, Optimizer
+from colorama import Back, Style
+
+from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
+from colossalai.testing import assert_close
+from rpc_test_utils import rpc_run, parse_args, RpcTestModel
+
+
+def run_master(args):
+    torch.manual_seed(100)
+
+    device = args.device
+    stage_num = args.world_size
+    chunk = args.chunk
+    actual_stage_num = stage_num * chunk
+    use_interleave = args.use_interleave
+    use_checkpoint = args.use_checkpoint
+    num_microbatches = args.num_microbatches
+    optimizer_class = globals()[args.optimizer]
+
+    lr = 1e-3
+
+    sample_num = 1024
+    feat_num = 100
+    h = 100
+    batch_size = 1024
+
+    assert sample_num % batch_size == 0
+    batch_num = sample_num // batch_size
+
+    input_sample = torch.randn((sample_num, feat_num), device=device)
+
+    module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
+
+    engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
+                                    stage_num=stage_num,
+                                    num_microbatches=num_microbatches,
+                                    device=device,
+                                    chunk=chunk,
+                                    use_interleave=use_interleave,
+                                    checkpoint=use_checkpoint)
+
+    engine.initialize_optimizer(optimizer_class, lr=lr)
+
+    _ = engine.forward_backward(input_sample)
+    engine.step()
+
+    cuda_rpc_result = []
+    single_result = []
+    actual_stage_num = engine._get_actual_stage_num()
+
+    # compute parameters after updating in cuda rpc
+    parameters = engine.remote_parameters()
+    for stage_id in range(actual_stage_num):
+        for p in parameters[stage_id]:
+            cuda_rpc_result.append(p)
+
+    # compute forward result and backward grad of parameters just in rank_0
+    test_model = nn.Sequential(*module_partitions).to(device)
+    optimizer: Optimizer = optimizer_class(test_model.parameters(), lr=lr)
+    input_sample = input_sample.requires_grad_()
+    out_val = test_model(input_sample).sum()
+    autograd.backward(out_val)
+    optimizer.step()
+    optimizer.zero_grad()
+
+    for p in test_model.parameters():
+        single_result.append(p)
+
+    assert len(cuda_rpc_result) == len(single_result)
+    for r_c, r_s in zip(cuda_rpc_result, single_result):
+        assert_close(r_c, r_s, 0.001, 0.001)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    rpc_run(args, run_master)
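The comparison loop at the end of run_master checks every pipeline-updated parameter against its single-process counterpart with rtol=0.001 and atol=0.001. Assuming colossalai.testing.assert_close follows the torch.testing semantics, a pair passes when |actual - expected| <= atol + rtol * |expected| elementwise:

    import torch

    a = torch.tensor([1.0000, 2.0010])
    b = torch.tensor([1.0005, 2.0000])
    torch.testing.assert_close(a, b, rtol=0.001, atol=0.001)    # within tolerance, passes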
@@ -7,32 +7,10 @@ import torch.multiprocessing as mp
 import torch.distributed.rpc as rpc
 
 from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
+from rpc_test_utils import rpc_run, parse_args, RpcTestModel
 
 
-class TestModel(nn.Module):
-
-    def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None:
-        super().__init__()
-        self.rank = stage_id
-        self.is_last_rank = stage_id == actual_stage_num - 1
-        self.linear_name = f'linear_{stage_id}'
-        if stage_id == 0:
-            setattr(self, self.linear_name, nn.Linear(feat_num, h))
-        elif stage_id == actual_stage_num - 1:
-            setattr(self, self.linear_name, nn.Linear(h, 1))
-        else:
-            setattr(self, self.linear_name, nn.Linear(h, h))
-
-    def forward(self, x) -> torch.Tensor:
-        linear: nn.Module = getattr(self, self.linear_name)
-        out: torch.Tensor = linear(x)
-
-        if self.is_last_rank:
-            out = out.sum()
-        return out
-
-
-def run_main(args):
+def run_master(args):
     torch.manual_seed(100)
 
     device = args.device
@@ -53,7 +31,7 @@ def run_main(args):
 
     input_sample = torch.randn((sample_num, feat_num), device=device)
 
-    module_partitions = [TestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
+    module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
 
     engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
                                     stage_num=stage_num,
@@ -66,44 +44,6 @@ def run_main(args):
     _ = engine.forward_backward(input_sample)
 
 
-def run_worker(rank, args):
-    os.environ['MASTER_ADDR'] = args.master_addr
-    os.environ['MASTER_PORT'] = args.master_port
-
-    # config rpc
-    # if cuda is used, set_device_map is a must is configured
-    # for cuda is not supported in torch rpc by default
-    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=args.num_worker_threads)
-
-    world_size = args.world_size
-    for rank_idx in range(world_size):
-        options.set_device_map(f'work{rank_idx}', {rank: rank_idx})
-
-    rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options)
-
-    # in rpc mode, only rank 0 is needed to be coded
-    if rank == 0:
-        run_main(args)
-    # barrier here
-    rpc.shutdown()
-
-
-def parse_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--world_size', type=int, default=2)
-    parser.add_argument('--num_microbatches', type=int, default=2)
-    parser.add_argument('--device', type=str, default='cuda')
-    parser.add_argument('--chunk', type=int, default=1)
-    parser.add_argument('--use_checkpoint', action='store_true')
-    parser.add_argument('--use_interleave', action='store_true')
-    parser.add_argument('--master_addr', type=str, default='localhost')
-    parser.add_argument('--master_port', type=str, default='29020')
-    parser.add_argument('--num_worker_threads', type=str, default=128)
-    return parser.parse_args()
-
-
 if __name__ == "__main__":
     args = parse_args()
-    world_size = args.world_size
-    assert args.device in ['cpu', 'cuda'], "device must be cpu or cuda!"
-    mp.spawn(run_worker, args=(args,), nprocs=world_size)
+    rpc_run(args, run_master)
@@ -9,37 +9,11 @@ from torch import autograd
 from colorama import Back, Style
 
 from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
+from colossalai.testing import assert_close
+from rpc_test_utils import rpc_run, parse_args, RpcTestModel
 
 
-def color_debug(text, prefix=' ', color='blue'):
-    color = color.upper()
-    print(getattr(Back, color), prefix, Style.RESET_ALL, text)
-
-
-class TestModel(nn.Module):
-
-    def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None:
-        super().__init__()
-        self.rank = stage_id
-        self.is_last_rank = stage_id == actual_stage_num - 1
-        self.linear_name = f'linear_{stage_id}'
-        if stage_id == 0:
-            setattr(self, self.linear_name, nn.Linear(feat_num, h))
-        elif stage_id == actual_stage_num - 1:
-            setattr(self, self.linear_name, nn.Linear(h, 1))
-        else:
-            setattr(self, self.linear_name, nn.Linear(h, h))
-
-    def forward(self, x) -> torch.Tensor:
-        linear: nn.Module = getattr(self, self.linear_name)
-        out: torch.Tensor = linear(x)
-
-        if self.is_last_rank:
-            out = out.sum()
-        return out
-
-
-def run_main(args):
+def run_master(args):
     torch.manual_seed(100)
 
     device = args.device
@@ -48,6 +22,7 @@ def run_main(args):
     actual_stage_num = stage_num * chunk
     use_interleave = args.use_interleave
     use_checkpoint = args.use_checkpoint
+    num_microbatches = args.num_microbatches
 
     sample_num = 1024
     feat_num = 100
@@ -57,11 +32,9 @@ def run_main(args):
     assert sample_num % batch_size == 0
     batch_num = sample_num // batch_size
 
-    num_microbatches = stage_num * 1
-
     input_sample = torch.randn((sample_num, feat_num), device=device)
 
-    module_partitions = [TestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
+    module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
 
     engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
                                     stage_num=stage_num,
@@ -77,74 +50,27 @@ def run_main(args):
     single_result = []
     actual_stage_num = engine._get_actual_stage_num()
 
-    # color_debug('cuda rpc forward', 'Test')
-    # print(sum(forward_result[0]))
-    cuda_rpc_result.append(sum(forward_result[0]).item())
-    # color_debug('cuda rpc backward', 'Test')
+    # compute forward result and backward grad of parameters in cuda rpc
+    cuda_rpc_result.append(sum(forward_result[0]))
     grad = engine.remote_grad()
     for stage_id in range(actual_stage_num):
        for p in grad[stage_id]:
-            # print(p.sum())
-            cuda_rpc_result.append(p.sum().item())
+            cuda_rpc_result.append(p)
 
+    # compute forward result and backward grad of parameters just in rank_0
     test_model = nn.Sequential(*module_partitions).to(device)
     input_sample = input_sample.requires_grad_()
     out_val = test_model(input_sample).sum()
     autograd.backward(out_val)
-    # color_debug('single forward', 'Test')
-    # print(out_val)
-    single_result.append(out_val.item())
-    # color_debug('single backward', 'Test')
+    single_result.append(out_val)
     for p in test_model.parameters():
-        # print(p.grad.sum())
-        single_result.append(p.grad.sum().item())
+        single_result.append(p.grad)
 
-    cuda_rpc_result = torch.tensor(cuda_rpc_result)
-    single_result = torch.tensor(single_result)
-    distance = (cuda_rpc_result - single_result).abs().sum().item()
-    kappa = round(distance / actual_stage_num, 5)
-    assert kappa < 0.01, f"kappa({kappa}) is too big, PP result may be incorrect!"
-
-
-def run_worker(rank, args):
-    os.environ['MASTER_ADDR'] = args.master_addr
-    os.environ['MASTER_PORT'] = args.master_port
-
-    # config rpc
-    # if cuda is used, set_device_map is a must is configured
-    # for cuda is not supported in torch rpc by default
-    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=args.num_worker_threads)
-
-    world_size = args.world_size
-    for rank_idx in range(world_size):
-        options.set_device_map(f'work{rank_idx}', {rank: rank_idx})
-
-    rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options)
-
-    # in rpc mode, only rank 0 is needed to be coded
-    if rank == 0:
-        run_main(args)
-    # barrier here
-    rpc.shutdown()
-
-
-def parse_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--world_size', type=int, default=2)
-    parser.add_argument('--num_microbatches', type=int, default=2)
-    parser.add_argument('--chunk', type=int, default=1)
-    parser.add_argument('--use_checkpoint', action='store_true')
-    parser.add_argument('--use_interleave', action='store_true')
-    parser.add_argument('--device', type=str, default='cuda')
-    parser.add_argument('--master_addr', type=str, default='localhost')
-    parser.add_argument('--master_port', type=str, default='29020')
-    parser.add_argument('--num_worker_threads', type=str, default=128)
-    return parser.parse_args()
+    assert len(cuda_rpc_result) == len(single_result)
+    for r_c, r_s in zip(cuda_rpc_result, single_result):
+        assert_close(r_c, r_s, 0.001, 0.001)
 
 
 if __name__ == "__main__":
     args = parse_args()
-    world_size = args.world_size
-    assert args.num_microbatches >= args.world_size, "num_microbatches cannot be fewer than world_size!"
-    assert args.device in ['cpu', 'cuda'], "device must be cpu or cuda!"
-    mp.spawn(run_worker, args=(args,), nprocs=world_size)
+    rpc_run(args, run_master)