mirror of https://github.com/hpcaitech/ColossalAI
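
"""Proof-of-concept tests for splitting the backward pass into a B step (input
gradient, dx = w * dy) and a W step (weight gradient, dw = x * dy).

The file builds the idea up in stages: single-layer splits, repeated steps,
memory behaviour, manual model chunking, and finally a 4-rank interleaved
pipeline driven by PipelineP2PCommunication.
"""
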
import gc
from copy import deepcopy
from typing import Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn

IN_DIM = 8192
OUT_DIM = 8192
NUM_LAYER = 3


class MlpModel(nn.Module):
    def __init__(self, in_dim=IN_DIM, out_dim=OUT_DIM, num_layers=NUM_LAYER):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
    num_params = 0
    num_params_trainable = 0
    for p in model.parameters():
        num_params += p.numel()
        if p.requires_grad:
            num_params_trainable += p.numel()
    return num_params, num_params_trainable


# Step1: dx = w*dy
def backward_b(loss, x, model):
    print(f"Before bwd b: {torch.cuda.memory_allocated()/1024**3 :.3f} GB")
    # print(f"Before x grad {x.grad}")
    # for name, param in model.named_parameters():
    #     print(f"Before bwd b \n param {param}\n param grad {param.grad}\n")

    torch.autograd.backward(loss, inputs=x, retain_graph=True)

    # for name, param in model.named_parameters():
    #     print(f"After bwd b \n param {param}\n param grad {param.grad}\n")

    # print(f"After x grad {x.grad}")
    print(f"After bwd b: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")


# Step1: dx = w*dy; for layer not last
def backward_b_not_last(tensors, grad, x, model):
    print(f"Before bwd b: {torch.cuda.memory_allocated()/1024**3 :.3f} GB")
    torch.autograd.backward(tensors=tensors, grad_tensors=grad, inputs=x, retain_graph=True)
    print(f"After bwd b: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")


# Step2: dw = x*dy
def backward_w(loss, model):
    print(f"Before bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    # for name, param in model.named_parameters():
    #     print(f"Before bwd w \n param {param}\n param grad {param.grad}\n")

    torch.autograd.backward(loss, inputs=list(model.parameters()))

    # for name, param in model.named_parameters():
    #     print(f"After bwd w \n param {param}\n param grad {param.grad}\n")

    print(f"After bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")


# Step2: dw = x*dy; for layer not last
def backward_w_not_last(tensors, grad, model):
    print(f"Before bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    torch.autograd.backward(tensors=tensors, grad_tensors=grad, inputs=list(model.parameters()))
    print(f"After bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")


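# The helpers above split one backward pass in two by restricting `inputs=` in
# torch.autograd.backward. A minimal, self-contained sketch of that idea (an
# added illustration, not one of the original tests; the name is hypothetical):
def _split_backward_sketch():
    w = nn.Linear(4, 4, bias=False)
    x = torch.randn(2, 4, requires_grad=True)
    loss = w(x).sum()
    # B step: only the input receives a gradient; the graph is kept for the W step.
    torch.autograd.backward(loss, inputs=x, retain_graph=True)
    assert x.grad is not None and w.weight.grad is None
    # W step: only the parameters receive gradients, reusing the retained graph.
    torch.autograd.backward(loss, inputs=list(w.parameters()))
    assert w.weight.grad is not None

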
def test_dx_dw_split():
    device = "cuda:0"
    model = nn.Linear(8, 8, bias=None).to(device=device)
    print(f"model numel {get_model_numel(model)}")
    x = torch.rand(8, 8).to(device=device)
    ref_model = deepcopy(model)
    ref_x = x.clone()

    # first step
    x.requires_grad_()
    loss = model(x).sum()
    backward_b(loss, x, model)
    for p in model.parameters():
        assert p.grad is None
    assert x.grad is not None
    backward_w(loss, model)
    for p in model.parameters():
        assert p.grad is not None

    # # second step
    # loss = model(x).sum()
    # backward_b(loss, x, model)
    # backward_w(loss, model)

    ref_x.requires_grad_()
    ref_loss = ref_model(ref_x).sum()
    ref_loss.backward()

    assert torch.equal(x.grad, ref_x.grad)
    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.equal(p1.grad, p2.grad)


def test_double_dx_dw_split_nsync():
    device = "cuda:0"
    model = nn.Linear(8, 8, bias=None).to(device=device)
    # print(f"model numel {get_model_numel(model)}")  # 4GB
    x1 = torch.rand(8, 8).to(device=device)
    x2 = torch.rand(8, 8).to(device=device)
    ref_model = deepcopy(model)
    ref_x1 = x1.clone()
    ref_x2 = x2.clone()

    # first step
    x1.requires_grad_()
    x2.requires_grad_()
    ref_x1.requires_grad_()
    ref_x2.requires_grad_()

    # loss for dx_dw bwd
    loss1 = model(x1).sum()
    loss2 = model(x2).sum()

    # loss for common bwd
    ref_loss1 = ref_model(ref_x1).sum()
    ref_loss2 = ref_model(ref_x2).sum()

    # dx1
    backward_b(loss1, x1, model)
    for p in model.parameters():
        assert p.grad is None
    assert x1.grad is not None

    # dx2
    backward_b(loss2, x2, model)

    # dw1
    backward_w(loss1, model)
    for p in model.parameters():
        assert p.grad is not None

    # common bwd 1
    ref_loss1.backward()

    # assert dx1 & dw1 == bwd 1
    assert_close(x1.grad, ref_x1.grad)
    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)

    # dw2
    backward_w(loss2, model)

    # common bwd 2
    ref_loss2.backward()

    # assert dx2 & dw2 == bwd 2
    assert_close(x2.grad, ref_x2.grad)
    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)


def test_double_dx_dw_split_sync():
    device = "cuda:0"
    model = nn.Linear(8, 8, bias=None).to(device=device)
    # print(f"model numel {get_model_numel(model)}")  # 4GB
    x1 = torch.rand(8, 8).to(device=device)
    x2 = torch.rand(8, 8).to(device=device)

    # x1 = torch.ones(8, 8).to(device=device)
    # x2 = torch.ones(8, 8).to(device=device)

    ref_model = deepcopy(model)
    ref_x1 = x1.clone()
    ref_x2 = x2.clone()

    x1.requires_grad_()
    x2.requires_grad_()
    ref_x1.requires_grad_()
    ref_x2.requires_grad_()

    ############
    # step1:
    ############
    print(f"Step1\n")

    # loss1
    loss1 = model(x1).sum()

    # ref_loss1
    ref_loss1 = ref_model(ref_x1).sum()

    # dx1
    backward_b(loss1, x1, model)
    for p in model.parameters():
        assert p.grad is None
    assert x1.grad is not None

    # dw1
    backward_w(loss1, model)
    for p in model.parameters():
        assert p.grad is not None

    # common bwd 1
    ref_loss1.backward()

    # assert dx1 & dw1 == bwd 1
    assert_close(x1.grad, ref_x1.grad)
    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)

    ############
    # step2:
    ############
    print(f"Step2\n")

    # loss2
    loss2 = model(x2).sum()

    # ref_loss2
    ref_loss2 = ref_model(ref_x2).sum()

    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        # print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)

    # dx2
    backward_b(loss2, x2, model)

    # dw2
    backward_w(loss2, model)

    # common bwd 2
    ref_loss2.backward()

    # assert dx2 & dw2 == bwd 2
    assert_close(x2.grad, ref_x2.grad)
    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        # print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)


def deallocate_output_tensor(out):
    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.

    This method should be called right after the output tensor has been
    sent to the next pipeline stage. At this point, the output tensor is
    only useful for its '.grad_fn' field, and not its '.data'.
    """
    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
    assert out._base is None, "counter-productive to free a view of another tensor."
    out.data = torch.empty(
        (1,),
        device=out.device,
        dtype=out.dtype,
    )


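# Hypothetical usage sketch for deallocate_output_tensor (the memory tests below
# only reference it in commented-out lines): once an activation has been sent to
# the next stage, its storage can be released while its grad_fn is kept alive,
# e.g.
#     out = model_chunk(x)
#     comm.send_forward(out, next_rank)
#     deallocate_output_tensor(out)  # out.data becomes a 1-element placeholder

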
# del loss and x
def mem_dx_dw():
    device = "cuda:0"
    # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    model = MlpModel().to(device=device)
    print(f"model numel {get_model_numel(model)}")  # 4GB
    print(f"After init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    print(f"Before init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
    x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
    x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device)

    x1.requires_grad_()
    x2.requires_grad_()
    x3.requires_grad_()
    print(f"After init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ############
    # step1:
    ############
    print(f"\nStep1")

    # loss1
    print(f"Before Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    loss1 = model(x1).sum()
    print(f"After Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    print(f"Before loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    print(f"After loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    # dx1
    backward_b(loss1, x1, model)

    # dw1
    backward_w(loss1, model)

    # deallocate_output_tensor(x1)
    # deallocate_output_tensor(loss1)
    del loss1, x1
    # del x1
    # del y1
    print(f"After del x1&y1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ############
    # step2:
    ############
    print(f"\nStep2")

    # loss2
    loss2 = model(x2).sum()

    # dx2
    backward_b(loss2, x2, model)

    # dw2
    backward_w(loss2, model)

    # deallocate_output_tensor(x2)
    # deallocate_output_tensor(loss2)
    del x2, loss2
    # del x2
    # del y2
    print(f"After del x2&y2: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ############
    # step3:
    ############
    print(f"\nStep3")

    # loss3
    loss3 = model(x3).sum()

    # dx3
    backward_b(loss3, x3, model)

    # dw3
    backward_w(loss3, model)

    # deallocate_output_tensor(x3)
    # deallocate_output_tensor(loss3)
    # del x3
    # del y3
    del x3, loss3

    print(f"After del x3&y3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    param_ids = [id(p) for p in model.parameters()]
    for obj in gc.get_objects():
        if torch.is_tensor(obj) and id(obj) not in param_ids:
            print(obj)


# del activation
def activation_dx_dw():
    device = "cuda:0"
    # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    model = MlpModel().to(device=device)
    x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
    x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
    x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device)

    x1.requires_grad_()
    x2.requires_grad_()
    x3.requires_grad_()
    print(f"After init Model, x1,x2,x3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    # activations = {}
    # def register_hooks(module):
    #     def activation_hook(module, input, output):
    #         activations[f"{module.__class__.__name__}_{id(module)}"] = output.detach()
    #     def bwd_hook(module, grad_input, grad_output):
    #         del activations[f"{module.__class__.__name__}_{id(module)}"]
    #     module.register_forward_hook(activation_hook)
    #     module.register_backward_hook(bwd_hook)

    # model.apply(register_hooks)

    ############
    # step1:
    ############
    print(f"\nStep1")

    # loss1
    output1 = model(x1)
    loss1 = output1.sum()

    # dx1
    backward_b(loss1, x1, model)

    # for name, p in model.named_parameters():
    #     print(f"p grad {p.grad}")

    # dw1
    backward_w(loss1, model)

    # for name, p in model.named_parameters():
    #     del p.grad

    # del loss1, x1
    del loss1, x1, output1
    print(f"After del : {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ############
    # step2:
    ############
    print(f"\nStep2")

    # loss2
    output2 = model(x2)
    loss2 = output2.sum()

    # dx2
    backward_b(loss2, x2, model)

    # for name, p in model.named_parameters():
    #     print(f"p grad {p.grad}")

    # dw2
    backward_w(loss2, model)

    # for name, p in model.named_parameters():
    #     print(f"p grad {p.grad}")

    # del x2, loss2
    del x2, loss2, output2
    print(f"After del : {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ############
    # step3:
    ############
    print(f"\nStep3")

    # loss3
    output3 = model(x3)
    loss3 = output3.sum()

    # dx3
    backward_b(loss3, x3, model)

    # dw3
    backward_w(loss3, model)

    # del x3, loss3
    del x3, loss3, output3

    print(f"After del : {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")


def model_chunk_dx_dw():
    device = "cuda:0"
    num_layers = 4
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    model = MlpModel(in_dim=4096, out_dim=4096, num_layers=num_layers).to(device=device)
    input = torch.rand(4096, 4096, requires_grad=True).to(device=device)

    input_base = input.clone()

    model_base = deepcopy(model)

    ##########################
    # Fwd bwd for dx dw
    ##########################

    model_chunk_0 = torch.nn.Sequential()  # for layer 1 & 2
    model_chunk_1 = torch.nn.Sequential()  # for layer 3 & 4

    for idx, sub_model in enumerate(model.layers):
        if idx < 2:
            model_chunk_0.append(sub_model)
        else:
            model_chunk_1.append(sub_model)

    print(f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # Step1:chunk 0 fwd
    ##########################
    output1 = model_chunk_0(input)

    # detach output1; then output1 for chunk 0, output1_dt for chunk 1;
    output1_dt = output1.detach()
    output1_dt.requires_grad_()
    print(f"After chunk0 fwd (include detach output1): {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # Step2:chunk 1 fwd
    ##########################
    output2 = model_chunk_1(output1_dt)

    print(f"After chunk1 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # Step3:chunk 1 bwd b: dx=w*dy & bwd w:dw=x*dy
    ##########################
    loss = output2.mean()
    backward_b(loss, output1_dt, model_chunk_1)
    backward_w(loss, model_chunk_1)

    print(f"After chunk1 bwd b & w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # Step4:chunk 0 bwd b: dx=w*dy & bwd w:dw=x*dy
    ##########################
    # dx = w*dy
    backward_b_not_last(tensors=output1, grad=output1_dt.grad, x=input, model=model_chunk_0)
    backward_w_not_last(tensors=output1, grad=output1_dt.grad, model=model_chunk_0)

    print(f"After chunk0 bwd b & w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # Fwd bwd for base
    ##########################

    # fwd & bwd
    output_base = model_base(input_base)

    loss_base = output_base.mean()

    loss_base.backward()
    print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # Assert param
    ##########################

    assert_close(output2, output_base)
    assert_close(output2.grad, output_base.grad)

    for p1, p2 in zip(model.parameters(), model_base.parameters()):
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)

    del output1, output1_dt, output2, loss, loss_base, output_base
    print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")


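# Chunk-boundary pattern used by model_chunk_dx_dw above and by the distributed
# variants below: the producer chunk keeps its output (and graph) for its own
# B/W passes, while the consumer chunk runs on a detached copy that acts as a
# fresh leaf; the gradient accumulated on the detached copy is then fed back to
# the producer as `grad` in backward_b_not_last / backward_w_not_last.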
def model_chunk_dx_dw_communication(
    rank: int,
    world_size: int,
    port: int,
):
    # init dist
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
    pg_mesh = ProcessGroupMesh(world_size)
    stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=2)
    rank = dist.get_rank()
    comm = PipelineP2PCommunication(stage_manager, overlap_p2p=False)

    print(f"{stage_manager.get_rank()}")

    # init model and input
    num_layers = 4
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
    model = MlpModel(in_dim=4096, out_dim=4096, num_layers=num_layers).to(rank)
    input = torch.rand(4096, 4096, requires_grad=True).to(rank)

    input_base = input.clone()
    model_base = deepcopy(model)

    if rank == 0:
        model_chunk_0 = torch.nn.Sequential().to(rank)  # for layer 1 & 2 on rank0
        for idx, sub_model in enumerate(model.layers):
            if idx < 2:
                model_chunk_0.append(sub_model)
    else:
        model_chunk_1 = torch.nn.Sequential().to(rank)  # for layer 3 & 4 on rank1
        for idx, sub_model in enumerate(model.layers):
            if idx >= 2:
                model_chunk_1.append(sub_model)

    print(
        f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
    )

    ##########################
    # Step1:chunk 0 fwd
    ##########################
    if rank == 0:
        output1 = model_chunk_0(input)
        # detach output1; then output1 for chunk 0, output1_dt for chunk 1;
        # output1_dt_rank0 = output1.detach()
        # output1_dt_rank0.requires_grad_()
        print(
            f"After chunk0 fwd (include detach output1): {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )
        # send y(output1_dt) to next stage
        comm.send_forward(output1, stage_manager.get_next_rank())

    ##########################
    # Step2:chunk 1 fwd
    ##########################
    if rank == 1:
        # recv y(output1_dt) from prev stage
        output1_dt_rank1, wait_handles = comm.recv_forward(stage_manager.get_prev_rank())
        output1_dt_rank1.requires_grad_()
        output2 = model_chunk_1(output1_dt_rank1)

        print(
            f"After chunk1 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

    ##########################
    # Step3:chunk 1 on device_1 bwd b: dx=w*dy & bwd w:dw=x*dy
    ##########################
    if rank == 1:
        loss = output2.mean()
        backward_b(loss, output1_dt_rank1, model_chunk_1)
        backward_w(loss, model_chunk_1)

        print(f"After chunk1 bwd b & w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
        # send bwd output1_dt_rank1 from rank1 to rank 0
        comm.send_backward(output1_dt_rank1.grad, stage_manager.get_prev_rank())
    ##########################
    # Step4:chunk 0 on device_0 bwd b: dx=w*dy & bwd w:dw=x*dy
    ##########################

    if rank == 0:
        # recv bwd output1_dt_rank1 from rank1 to rank 0
        output1_dt_rank0_grad, _ = comm.recv_backward(stage_manager.get_next_rank())

        backward_b_not_last(tensors=output1, grad=output1_dt_rank0_grad, x=input, model=model_chunk_0)
        backward_w_not_last(tensors=output1, grad=output1_dt_rank0_grad, model=model_chunk_0)

        print(f"After chunk0 bwd b & w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # Fwd bwd for base
    ##########################
    # fwd & bwd
    output_base = model_base(input_base)
    loss_base = output_base.mean()
    loss_base.backward()
    print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # Assert param
    ##########################
    # assert output
    if rank == 1:
        assert_close(output2, output_base)
        assert_close(output2.grad, output_base.grad)

    # assert model param & grad
    if rank == 0:
        count = 0
        for (chunk_name, chunk_param), (base_name, base_param) in zip(
            model_chunk_0.named_parameters(), model_base.named_parameters()
        ):
            if count < 2:
                assert_close(chunk_param, base_param)
                assert_close(chunk_param.grad, base_param.grad)
            count += 1
    if rank == 1:
        count = 0
        for (chunk_name, chunk_param), (base_name, base_param) in zip(
            model_chunk_1.named_parameters(), model_base.named_parameters()
        ):
            if count >= 2:
                assert_close(chunk_param, base_param)
                assert_close(chunk_param.grad, base_param.grad)
            count += 1
    # clean memory
    if rank == 0:
        del output1, output1_dt_rank0_grad
    if rank == 1:
        del output2, loss, output1_dt_rank1
    del loss_base, output_base
    print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")


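# schedule_f / schedule_b below run a single micro-step of the interleaved
# schedule for one model chunk on one rank: receive the activation (or gradient)
# from the neighbouring stage, run the chunk forward (or the split B/W backward),
# and send the result onward. Chunk 0 flows first stage -> last stage, chunk 1
# flows back last stage -> first stage, where the loss is computed.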
# Return: output, loss
def schedule_f(
    stage_manager: PipelineStageManager,
    comm: PipelineP2PCommunication,
    input: torch.Tensor,
    model_chunk: torch.nn.ModuleList,
    model_chunk_id: int,
):
    # chunk_id == 0
    if model_chunk_id == 0:
        # recv fwd from prev
        if stage_manager.is_first_stage(ignore_chunk=True):
            input = input  # get local input
        else:
            prev_rank = stage_manager.get_prev_rank()
            input, wait_handles = comm.recv_forward(prev_rank)

        # fwd step
        output = model_chunk[model_chunk_id](input)

        # send fwd to next
        if stage_manager.is_last_stage(ignore_chunk=True):
            return input, output, None  # return local output
        else:
            next_rank = stage_manager.get_next_rank()
            comm.send_forward(output, next_rank)
            return input, output, None

    # chunk_id == 1
    if model_chunk_id == 1:
        # recv fwd from next
        if stage_manager.is_last_stage(ignore_chunk=True):
            input = input  # get local input
        else:
            next_rank = stage_manager.get_next_rank()
            input, wait_handles = comm.recv_forward(next_rank)

        # fwd step
        output = model_chunk[model_chunk_id](input)

        # send fwd to prev
        if stage_manager.is_first_stage(ignore_chunk=True):
            loss = output.mean()
            return input, output, loss  # return local output
        else:
            prev_rank = stage_manager.get_prev_rank()
            comm.send_forward(output, prev_rank)
            return input, output, None


def schedule_b(
    stage_manager: PipelineStageManager,
    comm: PipelineP2PCommunication,
    input: torch.Tensor,  # x
    output: torch.Tensor,  # y
    output_grad: torch.Tensor,  # dy
    model_chunk: torch.nn.ModuleList,
    model_chunk_id: int,
):
    # chunk_id == 0
    if model_chunk_id == 0:

        # recv bwd from next
        if stage_manager.is_last_stage(ignore_chunk=True):
            output_grad = output_grad  # get dy from local
        else:
            next_rank = stage_manager.get_next_rank()
            output_grad, _ = comm.recv_backward(next_rank)

        # bwd step
        backward_b_not_last(tensors=output, grad=output_grad, x=input, model=model_chunk[model_chunk_id])

        backward_w_not_last(tensors=output, grad=output_grad, model=model_chunk[model_chunk_id])

        # send bwd to prev
        if stage_manager.is_first_stage(ignore_chunk=True):
            return input.grad
        else:
            prev_rank = stage_manager.get_prev_rank()
            comm.send_backward(input.grad, prev_rank)

    # chunk_id == 1
    if model_chunk_id == 1:
        # recv bwd from prev
        if stage_manager.is_first_stage(ignore_chunk=True):
            output_grad = output_grad
        else:
            prev_rank = stage_manager.get_prev_rank()
            # print(f"prev_rank {prev_rank} curr rank {stage_manager.get_rank()}")
            output_grad, _ = comm.recv_backward(next_rank=prev_rank)

        # bwd step
        # print(f"Before input grad {input.grad}")
        # for name, param in model_chunk[model_chunk_id].named_parameters():
        #     print(f"Before {name} grad {param.grad}")

        if stage_manager.is_first_stage(ignore_chunk=True):
            backward_b(loss=output_grad, x=input, model=model_chunk[model_chunk_id])
            backward_w(loss=output_grad, model=model_chunk[model_chunk_id])
        else:
            # common bwd step
            # print(f"output_grad {output_grad}")
            backward_b_not_last(tensors=output, grad=output_grad, x=input, model=model_chunk[model_chunk_id])
            backward_w_not_last(tensors=output, grad=output_grad, model=model_chunk[model_chunk_id])

        # print(f"After input grad {input.grad}")
        # for name, param in model_chunk[model_chunk_id].named_parameters():
        #     print(f"After {name} grad {param.grad}")

        # send bwd to next
        if stage_manager.is_last_stage(ignore_chunk=True):
            return input.grad
        else:
            next_rank = stage_manager.get_next_rank()
            comm.send_backward(input.grad, next_rank)

    return input.grad


def schedule_w():
    pass


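# Layer placement used below ("V"-interleave of 8 layers over 4 ranks, 2 chunks
# per rank): rank 0 holds layers 0 & 7, rank 1 holds 1 & 6, rank 2 holds 2 & 5,
# rank 3 holds 3 & 4. Forward runs rank 0 -> 3 for chunk 0, then rank 3 -> 0 for
# chunk 1 (loss on rank 0); the backward pass retraces the path in reverse using
# the split B/W steps from schedule_b.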
def model_chunk_dx_dw_comm_interleaved(
    rank: int,
    world_size: int,
    port: int,
):
    # init dist
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
    pg_mesh = ProcessGroupMesh(world_size)
    stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=world_size)
    rank = dist.get_rank()
    comm = PipelineP2PCommunication(stage_manager, overlap_p2p=False)

    # init model and input
    num_layers = 8
    in_dim = out_dim = 2048
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
    model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
    input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)

    input_base = input0.clone()
    model_base = deepcopy(model)

    if rank == 0:
        # layer 0 & 7 to chunk 0 on rank0
        chunk_0 = torch.nn.ModuleList().to(rank)
        for idx, sub_model in enumerate(model.layers):
            if idx == 0 or idx == 7:
                chunk_0.append(sub_model)
    elif rank == 1:
        # layer 1 & 6 to chunk 1 on rank1
        chunk_1 = torch.nn.ModuleList().to(rank)
        for idx, sub_model in enumerate(model.layers):
            if idx == 1 or idx == 6:
                chunk_1.append(sub_model)
    elif rank == 2:
        # layer 2 & 5 to chunk 2 on rank2
        chunk_2 = torch.nn.ModuleList().to(rank)
        for idx, sub_model in enumerate(model.layers):
            if idx == 2 or idx == 5:
                chunk_2.append(sub_model)
    else:
        # layer 3 & 4 to chunk 3 on rank3
        chunk_3 = torch.nn.Sequential().to(rank)
        for idx, sub_model in enumerate(model.layers):
            if idx == 3 or idx == 4:
                chunk_3.append(sub_model)

    # # test checkpoint
    # check_fn = lambda submodule: isinstance(submodule, (Linear))
    # non_reentrant_wrapper = partial(
    #     checkpoint_wrapper,
    #     # checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    #     checkpoint_impl=CheckpointImpl.REENTRANT,
    # )
    # apply_activation_checkpointing(
    #     model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    # )

    print(
        f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
    )
    # set_checkpoint_early_stop(False)
    # buffer use to save input and output

    ##########################
    # Step1: fwd
    ##########################
    ######
    # fwd 1->4
    ######
    # chunk 0 id 0 (layer 0) fwd
    if rank == 0:
        chunk_id = 0
        input0, output0, _ = schedule_f(
            stage_manager=stage_manager,
            comm=comm,
            input=input0,
            model_chunk=chunk_0,
            model_chunk_id=chunk_id,
        )
        print(
            f"chunk 0 id 0 (layer 0)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

    # chunk 1 id 0 (layer 1) fwd
    if rank == 1:
        chunk_id = 0
        input1, output1, _ = schedule_f(
            stage_manager=stage_manager,
            comm=comm,
            input=None,
            model_chunk=chunk_1,
            model_chunk_id=chunk_id,
        )
        print(
            f"chunk 1 id 0 (layer 1)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

    # chunk 2 id 0 (layer 2) fwd
    if rank == 2:
        chunk_id = 0
        input2, output2, _ = schedule_f(
            stage_manager=stage_manager,
            comm=comm,
            input=None,
            model_chunk=chunk_2,
            model_chunk_id=chunk_id,
        )
        print(
            f"chunk 2 id 0 (layer 2)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

    # chunk 3 id 0 (layer 3) fwd
    if rank == 3:
        chunk_id = 0
        input3, output3, _ = schedule_f(
            stage_manager=stage_manager,
            comm=comm,
            input=None,
            model_chunk=chunk_3,
            model_chunk_id=chunk_id,
        )
        print(
            f"chunk 3 id 0 (layer 3)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

    ######
    # fwd 4->1
    ######

    if rank == 3:
        chunk_id = 1
        input4, output4, _ = schedule_f(
            stage_manager=stage_manager,
            comm=comm,
            input=output3,
            model_chunk=chunk_3,
            model_chunk_id=chunk_id,
        )
        print(
            f"chunk 3 id 1 (layer 4)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

    if rank == 2:
        chunk_id = 1
        input5, output5, _ = schedule_f(
            stage_manager=stage_manager,
            comm=comm,
            input=None,
            model_chunk=chunk_2,
            model_chunk_id=chunk_id,
        )
        print(
            f"chunk 2 id 1 (layer 5)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

    if rank == 1:
        chunk_id = 1
        input6, output6, _ = schedule_f(
            stage_manager=stage_manager,
            comm=comm,
            input=None,
            model_chunk=chunk_1,
            model_chunk_id=chunk_id,
        )
        print(
            f"chunk 1 id 1 (layer 6)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

    if rank == 0:
        chunk_id = 1
        input7, output7, loss = schedule_f(
            stage_manager=stage_manager,
            comm=comm,
            input=None,
            model_chunk=chunk_0,
            model_chunk_id=chunk_id,
        )
        # print(f"fwd output {output7}")
        print(
            f"chunk 0 id 1 (layer 7)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

    ##########################
    # Step2: bwd
    ##########################
    ######
    # bwd rank 4->1
    ######
    # chunk 0 id 1 (layer 7) bwd
    if rank == 0:
        chunk_id = 1
        input_grad7 = schedule_b(
            stage_manager=stage_manager,
            comm=comm,
            input=input7,  # x
            output=output7,  # y
            output_grad=loss,  # dy
            model_chunk=chunk_0,
            model_chunk_id=chunk_id,
        )

    # chunk 1 id 1 (layer 6) bwd
    if rank == 1:
        chunk_id = 1
        input_grad6 = schedule_b(
            stage_manager=stage_manager,
            comm=comm,
            input=input6,  # x
            output=output6,  # y
            output_grad=None,  # dy
            model_chunk=chunk_1,
            model_chunk_id=chunk_id,
        )

    # chunk 2 id 1 (layer 5) bwd
    if rank == 2:
        chunk_id = 1
        input_grad5 = schedule_b(
            stage_manager=stage_manager,
            comm=comm,
            input=input5,  # x
            output=output5,  # y
            output_grad=None,  # dy
            model_chunk=chunk_2,
            model_chunk_id=chunk_id,
        )

    # chunk 3 id 1 (layer 4) bwd
    if rank == 3:
        chunk_id = 1
        input_grad4 = schedule_b(
            stage_manager=stage_manager,
            comm=comm,
            input=input4,  # x
            output=output4,  # y
            output_grad=None,  # dy
            model_chunk=chunk_3,
            model_chunk_id=chunk_id,
        )
        # print(f"input_grad4 {input_grad4}")

    ######
    # bwd rank 1->4
    ######

    # chunk 3 id 0 (layer 3) bwd
    if rank == 3:
        chunk_id = 0
        input_grad3 = schedule_b(
            stage_manager=stage_manager,
            comm=comm,
            input=input3,  # x
            output=output3,  # y
            output_grad=input_grad4,  # dy
            model_chunk=chunk_3,
            model_chunk_id=chunk_id,
        )
        # print(f"input_grad3 {input_grad3}")

    # chunk 2 id 0 (layer 2) bwd
    if rank == 2:
        chunk_id = 0
        input_grad2 = schedule_b(
            stage_manager=stage_manager,
            comm=comm,
            input=input2,  # x
            output=output2,  # y
            output_grad=None,  # dy
            model_chunk=chunk_2,
            model_chunk_id=chunk_id,
        )
        # print(f"input_grad2 {input_grad2}")

    # chunk 1 id 0 (layer 1) bwd
    if rank == 1:
        chunk_id = 0
        input_grad1 = schedule_b(
            stage_manager=stage_manager,
            comm=comm,
            input=input1,  # x
            output=output1,  # y
            output_grad=None,  # dy
            model_chunk=chunk_1,
            model_chunk_id=chunk_id,
        )

    # chunk 0 id 0 (layer 0) bwd
    if rank == 0:
        chunk_id = 0
        input_grad0 = schedule_b(
            stage_manager=stage_manager,
            comm=comm,
            input=input0,  # x
            output=output0,  # y
            output_grad=None,  # dy
            model_chunk=chunk_0,
            model_chunk_id=chunk_id,
        )
        # print(f"input_grad0 {input_grad0}")

    ##########################
    # Fwd bwd for base
    ##########################
    # fwd & bwd
    output_base = model_base(input_base)
    loss_base = output_base.mean()
    loss_base.backward()
    print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # Assert close
    ##########################
    # assert output
    if rank == 0:
        assert_close(output7, output_base)

    # assert weight
    if rank == 0:
        # layer 0
        assert_close(chunk_0[0].weight, model_base.layers[0].weight)
        assert_close(chunk_0[0].weight.grad, model_base.layers[0].weight.grad)
        # layer 7
        assert_close(chunk_0[1].weight, model_base.layers[7].weight)
        assert_close(chunk_0[1].weight.grad, model_base.layers[7].weight.grad)
    if rank == 1:
        # layer 1
        assert_close(chunk_1[0].weight, model_base.layers[1].weight)
        assert_close(chunk_1[0].weight.grad, model_base.layers[1].weight.grad)
        # layer 6
        assert_close(chunk_1[1].weight, model_base.layers[6].weight)
        assert_close(chunk_1[1].weight.grad, model_base.layers[6].weight.grad)

    if rank == 2:
        # layer 2
        assert_close(chunk_2[0].weight, model_base.layers[2].weight)
        assert_close(chunk_2[0].weight.grad, model_base.layers[2].weight.grad)
        # layer 5
        assert_close(chunk_2[1].weight, model_base.layers[5].weight)
        assert_close(chunk_2[1].weight.grad, model_base.layers[5].weight.grad)

    if rank == 3:
        # layer 3
        assert_close(chunk_3[0].weight, model_base.layers[3].weight)
        assert_close(chunk_3[0].weight.grad, model_base.layers[3].weight.grad)
        # layer 4
        assert_close(chunk_3[1].weight, model_base.layers[4].weight)
        assert_close(chunk_3[1].weight.grad, model_base.layers[4].weight.grad)

    # clean memory
    if rank == 0:
        del input0, output0, input_grad0, input7, output7, input_grad7, loss
    if rank == 1:
        del input1, output1, input_grad1, input6, output6, input_grad6
    if rank == 2:
        del input2, output2, input_grad2, input5, output5, input_grad5
    if rank == 3:
        del input3, output3, input_grad3, input4, output4, input_grad4
    # print(f"After del device: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")

    del loss_base, output_base

    print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")


@rerun_if_address_is_in_use()
def test_dx_dw_dist():
    spawn(
        model_chunk_dx_dw_comm_interleaved,
        nprocs=4,
    )


if __name__ == "__main__":
    # test_dx_dw_split()
    # test_double_dx_dw_split_nsync()
    # test_double_dx_dw_split_sync()
    # mem_dx_dw()
    # activation_dx_dw()
    # model_chunk_dx_dw()

    test_dx_dw_dist()