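"""Sanity-check ColossalAI's FP8 gradient-compression DDP communication hooks.

Each rank runs one DDP training step on a toy model without a comm hook and again
with the sync/async FP8 hooks, then compares the resulting gradients within a loose
tolerance. Requires at least two CUDA GPUs (NCCL backend).
"""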
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    def get_grads_after_one_iteration(hook=None):
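        # Fix the seed so both runs (with and without a comm hook) start from
        # identical weights and see identical input batches.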
        torch.manual_seed(0)
        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)

        ddp_model = DDP(model, device_ids=[rank])

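        # When a hook is given, DDP routes its gradient all-reduce through it; the
        # hooks under test compress gradients to FP8 before communication.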
        if hook is not None:
            ddp_model.register_comm_hook(None, hook)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(rank)
        loss_fn(outputs, labels).backward()
        optimizer.step()

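        # Make sure every rank has finished its iteration before gradients are read.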
        torch.distributed.barrier()

        grad_dict = {}
        for name, params in ddp_model.named_parameters():
            grad_dict[name] = params.grad
        return grad_dict

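    # ColossalAI provides both a synchronous and an asynchronous FP8 gradient-compression hook.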
    from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async, fp8_compress_ddp_grad_comm_hook_sync

    grad_dict = get_grads_after_one_iteration()
    for hook in [fp8_compress_ddp_grad_comm_hook_sync, fp8_compress_ddp_grad_comm_hook_async]:
        grad_dict_w_hook = get_grads_after_one_iteration(hook)
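        # FP8 compression is lossy, so the gradients are only expected to match the
        # uncompressed baseline within a loose tolerance.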
        if dist.get_rank() == 0:
            for name in grad_dict:
                assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)

    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)


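# Spawn one process per visible GPU; each process runs demo_basic with its own rank.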
if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_basic, world_size)