# mirror of https://github.com/hpcaitech/ColossalAI
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
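# This example spawns one process per GPU, runs a single DDP training step with and
# without ColossalAI's FP8 gradient-compression communication hooks, and checks that
# the resulting gradients stay close to those produced by vanilla DDP.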
def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()
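# Note: setup() relies on the default env:// rendezvous (MASTER_ADDR/MASTER_PORT) with
# the NCCL backend; port 12355 is an arbitrary port assumed to be free on a single node.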
class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))
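# ToyModel is a minimal two-layer MLP; it only needs enough parameters to exercise
# DDP's gradient bucketing and communication.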
def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)
    def get_grads_after_one_iteration(hook=None):
        torch.manual_seed(0)
        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)

        ddp_model = DDP(model, device_ids=[rank])

        if hook is not None:
            ddp_model.register_comm_hook(None, hook)

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10))
        labels = torch.randn(20, 5).to(rank)
        loss_fn(outputs, labels).backward()
        optimizer.step()

        torch.distributed.barrier()

        grad_dict = {}
        for name, params in ddp_model.named_parameters():
            grad_dict[name] = params.grad
        return grad_dict
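    # fp8_compress_ddp_grad_comm_hook_sync / _async are ColossalAI's DDP communication
    # hooks; as the names suggest, they compress gradient buckets to FP8 for the
    # reduction (synchronously or with communication overlap) and decompress the result.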
    from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async, fp8_compress_ddp_grad_comm_hook_sync
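    # FP8 compression is lossy, so gradients computed with a hook enabled are only
    # expected to match the uncompressed baseline within a loose tolerance
    # (rtol=atol=0.1); only rank 0 performs the comparison.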
    grad_dict = get_grads_after_one_iteration()
    for hook in [fp8_compress_ddp_grad_comm_hook_sync, fp8_compress_ddp_grad_comm_hook_async]:
        grad_dict_w_hook = get_grads_after_one_iteration(hook)
        if dist.get_rank() == 0:
            for name in grad_dict:
                assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1)
    cleanup()
def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)
if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_basic, world_size)
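# Usage sketch (assumes a single node with at least two CUDA GPUs and ColossalAI
# installed): run this file directly, e.g. `python <this_file>.py`; mp.spawn then
# launches one worker process per visible GPU.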