[feat] fix poc format

pull/6034/head
duanjunwen 2024-08-28 03:08:35 +00:00
parent d6e3d7d2a3
commit b5f7b4d228
1 changed file with 15 additions and 122 deletions


@@ -1,6 +1,5 @@
import gc
from copy import deepcopy
from typing import Tuple
import torch
import torch.distributed as dist
@@ -13,11 +12,13 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
# info of model
IN_DIM = 8192
OUT_DIM = 8192
NUM_LAYER = 3
# A simple MLP
class MlpModel(nn.Module):
def __init__(self, in_dim=IN_DIM, out_dim=OUT_DIM, num_layers=NUM_LAYER):
super().__init__()
@@ -29,29 +30,10 @@ class MlpModel(nn.Module):
return x
def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
num_params = 0
num_params_trainable = 0
for p in model.parameters():
num_params += p.numel()
if p.requires_grad:
num_params_trainable += p.numel()
return num_params, num_params_trainable
# Step 1: dx = w*dy
def backward_b(loss, x, model):
print(f"Before bwd b: {torch.cuda.memory_allocated()/1024**3 :.3f} GB")
# print(f"Before x grad {x.grad}")
# for name, param in model.named_parameters():
# print(f"Before bwd b \n param {param}\n param gard {param.grad}\n")
torch.autograd.backward(loss, inputs=x, retain_graph=True)
# for name, param in model.named_parameters():
# print(f"After bwd b \n param {param}\n param gard {param.grad}\n")
# print(f"After x grad {x.grad}")
print(f"After bwd b: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
@@ -64,15 +46,7 @@ def backward_b_not_last(tensors, grad, x, model):
def backward_w(loss, model):
print(f"Before bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
# for name, param in model.named_parameters():
# print(f"Before bwd w \n param {param}\n param gard {param.grad}\n")
torch.autograd.backward(loss, inputs=list(model.parameters()))
# for name, param in model.named_parameters():
# print(f"After bwd w \n param {param}\n param gard {param.grad}\n")
print(f"After bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
@@ -83,6 +57,7 @@ def backward_w_not_last(tensors, grad, model):
print(f"After bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
# In this poc, we check the feasibility of splitting dx and dw in bwd propagation
def test_dx_dw_split():
device = "cuda:0"
model = nn.Linear(8, 8, bias=None).to(device=device)
@@ -116,6 +91,8 @@ def test_dx_dw_split():
assert torch.equal(p1.grad, p2.grad)
# In this poc, we check the non-sync ordering of splitting dx and dw in bwd propagation, in the following order:
# fwd1 --> fwd2 --> dx1 --> dx2 --> dw1 --> dw2
def test_double_dx_dw_split_nsync():
device = "cuda:0"
model = nn.Linear(8, 8, bias=None).to(device=device)
@@ -177,16 +154,14 @@ def test_double_dx_dw_split_nsync():
assert_close(p1.grad, p2.grad)
# In this poc, we check the sync ordering of splitting dx and dw in bwd propagation, in the following order:
# fwd1 --> fwd2 --> dx1 --> dw1 --> dx2 --> dw2
def test_double_dx_dw_split_sync():
device = "cuda:0"
model = nn.Linear(8, 8, bias=None).to(device=device)
# print(f"model numel {get_model_numel(model)}") # 4GB
x1 = torch.rand(8, 8).to(device=device)
x2 = torch.rand(8, 8).to(device=device)
# x1 = torch.ones(8, 8).to(device=device)
# x2 = torch.ones(8, 8).to(device=device)
ref_model = deepcopy(model)
ref_x1 = x1.clone()
ref_x2 = x2.clone()
@@ -239,7 +214,6 @@ def test_double_dx_dw_split_sync():
ref_loss2 = ref_model(ref_x2).sum()
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
# print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
assert_close(p1, p2)
assert_close(p1.grad, p2.grad)
@@ -255,31 +229,13 @@ def test_double_dx_dw_split_sync():
# assert dx2 & dw2 == bwd 2
assert_close(x2.grad, ref_x2.grad)
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
# print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
assert_close(p1, p2)
assert_close(p1.grad, p2.grad)
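# A compact single-process sketch of the sync-ordering check above
# (fwd1 --> fwd2 --> dx1 --> dw1 --> dx2 --> dw2 vs. two plain backward passes);
# deepcopy / assert_close come from this file's imports, the rest is illustrative:
def sketch_split_order_matches_plain_bwd():
    model = nn.Linear(8, 8, bias=False)
    ref_model = deepcopy(model)
    x1 = torch.rand(8, 8, requires_grad=True)
    x2 = torch.rand(8, 8, requires_grad=True)
    ref_x1 = x1.detach().clone().requires_grad_()
    ref_x2 = x2.detach().clone().requires_grad_()
    loss1, loss2 = model(x1).sum(), model(x2).sum()  # fwd1, fwd2
    torch.autograd.backward(loss1, inputs=[x1], retain_graph=True)  # dx1
    torch.autograd.backward(loss1, inputs=list(model.parameters()))  # dw1
    torch.autograd.backward(loss2, inputs=[x2], retain_graph=True)  # dx2
    torch.autograd.backward(loss2, inputs=list(model.parameters()))  # dw2
    ref_model(ref_x1).sum().backward()  # plain bwd1
    ref_model(ref_x2).sum().backward()  # plain bwd2
    for p, ref_p in zip(model.parameters(), ref_model.parameters()):
        assert_close(p.grad, ref_p.grad)
    assert_close(x1.grad, ref_x1.grad)
    assert_close(x2.grad, ref_x2.grad)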
def deallocate_output_tensor(out):
"""Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
This method should be called right after the output tensor has been
sent to the next pipeline stage. At this point, the output tensor is
only useful for its '.grad_fn' field, and not its '.data'.
"""
assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
assert out._base is None, "counter-productive to free a view of another tensor."
out.data = torch.empty(
(1,),
device=out.device,
dtype=out.dtype,
)
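# A small sketch of the effect of deallocate_output_tensor (assumes a CUDA device;
# names are illustrative): the output's storage is released while its .grad_fn
# survives, which is all the later backward pass needs per the docstring above.
def sketch_deallocate_output():
    if not torch.cuda.is_available():
        return
    model = nn.Linear(1024, 1024, bias=False).cuda()
    x = torch.rand(1024, 1024, device="cuda", requires_grad=True)
    out = model(x)
    before = torch.cuda.memory_allocated()
    deallocate_output_tensor(out)  # out.data shrinks to a single element
    after = torch.cuda.memory_allocated()
    assert out.grad_fn is not None  # the autograd graph is still intact
    assert after < before  # the output's storage has been released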
# del loss and x
# In this poc, we check if a memory leak has occurred after deleting the input & loss (with graph)
def mem_dx_dw():
device = "cuda:0"
# model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
model = MlpModel().to(device=device)
print(f"model numel {get_model_numel(model)}") # 4GB
@@ -314,8 +270,6 @@ def mem_dx_dw():
# dw1
backward_w(loss1, model)
# deallocate_output_tensor(x1)
# deallocate_output_tensor(loss1)
del loss1, x1
# del x1
# del y1
@@ -335,8 +289,6 @@ def mem_dx_dw():
# dw2
backward_w(loss2, model)
# deallocate_output_tensor(x2)
# deallocate_output_tensor(loss2)
del x2, loss2
# del x2
# del y2
@@ -356,8 +308,6 @@ def mem_dx_dw():
# dw2
backward_w(loss3, model)
# deallocate_output_tensor(x3)
# deallocate_output_tensor(loss3)
# del x3
# del y3
del x3, loss3
@@ -370,7 +320,7 @@ def mem_dx_dw():
print(obj)
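# A small helper in the same spirit as the gc scan above (gc / torch are already
# imported in this file; the helper name is illustrative): it lists the CUDA
# tensors still reachable, which is a quick way to spot leaked activations.
def live_cuda_tensors():
    live = []
    for obj in gc.get_objects():
        try:
            if isinstance(obj, torch.Tensor) and obj.is_cuda:
                live.append((type(obj).__name__, tuple(obj.shape)))
        except Exception:
            # some gc-tracked objects do not tolerate attribute access; skip them
            continue
    return live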
# del activation
# In this poc, we check if a memory leak has occurred after deleting the input & loss (with graph) & activation
def activation_dx_dw():
device = "cuda:0"
# model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
@@ -385,17 +335,6 @@ def activation_dx_dw():
x3.requires_grad_()
print(f"After init Model, x1,x2,x3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
# activations = {}
# def register_hooks(module):
# def activation_hook(module, input, output):
# activations[f"{module.__class__.__name__}_{id(module)}"] = output.detach()
# def bwd_hook(module, grad_input, grad_output):
# del activations[f"{module.__class__.__name__}_{id(module)}"]
# module.register_forward_hook(activation_hook)
# module.register_backward_hook(bwd_hook)
# model.apply(register_hooks)
############
# step1:
############
@@ -408,15 +347,9 @@ def activation_dx_dw():
# dx1
backward_b(loss1, x1, model)
# for name, p in model.named_parameters():
# print(f"p grad {p.grad}")
# dw1
backward_w(loss1, model)
# for name, p in model.named_parameters():
# del p.grad
# del loss1, x1
del loss1, x1, output1
print(f"After del : {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
@@ -433,15 +366,9 @@ def activation_dx_dw():
# dx2
backward_b(loss2, x2, model)
# for name, p in model.named_parameters():
# print(f"p grad {p.grad}")
# dw2
backward_w(loss2, model)
# for name, p in model.named_parameters():
# print(f"p grad {p.grad}")
# del x2, loss2
del x2, loss2, output2
print(f"After del : {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
@@ -467,6 +394,7 @@ def activation_dx_dw():
print(f"After del : {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
# In this poc, we apply model chunks instead of layers
def model_chunk_dx_dw():
device = "cuda:0"
num_layers = 4
@@ -555,6 +483,7 @@ def model_chunk_dx_dw():
print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
# In this poc, we apply model chunks and a pp group for communication
def model_chunk_dx_dw_communication(
rank: int,
world_size: int,
@@ -598,9 +527,6 @@ def model_chunk_dx_dw_communication(
##########################
if rank == 0:
output1 = model_chunk_0(input)
# detach output1; then output1 for chunk 0, output1_dt for chunk 1;
# output1_dt_rank0 = output1.detach()
# output1_dt_rank0.requires_grad_()
print(
f"After chunk0 fwd (include detach output1): {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
)
@@ -689,7 +615,7 @@ def model_chunk_dx_dw_communication(
print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
# Return: output, loss
# fwd schedule
def schedule_f(
stage_manager: PipelineStageManager,
comm: PipelineP2PCommunication,
@@ -738,6 +664,7 @@ def schedule_f(
return input, output, None
# bwd b schedule
def schedule_b(
stage_manager: PipelineStageManager,
comm: PipelineP2PCommunication,
@@ -759,7 +686,6 @@ def schedule_b(
# bwd step
backward_b_not_last(tensors=output, grad=output_grad, x=input, model=model_chunk[model_chunk_id])
backward_w_not_last(tensors=output, grad=output_grad, model=model_chunk[model_chunk_id])
# send bwd to prev
@@ -776,27 +702,17 @@ def schedule_b(
output_grad = output_grad
else:
prev_rank = stage_manager.get_prev_rank()
# print(f"prev_rank {prev_rank} curr rank {stage_manager.get_rank()}")
output_grad, _ = comm.recv_backward(next_rank=prev_rank)
# bwd step
# print(f"Before input grad {input.grad}")
# for name, param in model_chunk[model_chunk_id].named_parameters():
# print(f"Before {name} grad {param.grad}")
if stage_manager.is_first_stage(ignore_chunk=True):
backward_b(loss=output_grad, x=input, model=model_chunk[model_chunk_id])
backward_w(loss=output_grad, model=model_chunk[model_chunk_id])
else:
# common bwd step
# print(f"output_grad {output_grad}")
backward_b_not_last(tensors=output, grad=output_grad, x=input, model=model_chunk[model_chunk_id])
backward_w_not_last(tensors=output, grad=output_grad, model=model_chunk[model_chunk_id])
# print(f"After input grad {input.grad}")
# for name, param in model_chunk[model_chunk_id].named_parameters():
# print(f"After {name} grad {param.grad}")
# send bwd to next
if stage_manager.is_last_stage(ignore_chunk=True):
return input.grad
@@ -807,10 +723,12 @@ def schedule_b(
return input.grad
# bwd w schedule (dw already split in schedule b)
def schedule_w():
pass
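# A toy single-process sketch of what schedule_f / schedule_b amount to on one
# rank (schedule_w is empty because dw already runs inside schedule_b). Only
# torch / nn are assumed, and no p2p communication is modelled here.
def sketch_f_b_w_schedule(num_microbatches: int = 3):
    model = nn.Sequential(nn.Linear(8, 8, bias=False), nn.Linear(8, 8, bias=False))
    for _ in range(num_microbatches):
        x = torch.rand(4, 8, requires_grad=True)
        loss = model(x).sum()  # F
        # B: dx only; retain_graph keeps the graph so W can run as a separate step
        torch.autograd.backward(loss, inputs=[x], retain_graph=True)
        # W: dw; because B and W are separate calls, this step could also be
        # deferred and run later by a more elaborate schedule
        torch.autograd.backward(loss, inputs=list(model.parameters()))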
# In this poc, we apply a scheduling method for each rank: schedule_f --> schedule_b --> schedule_w
def model_chunk_dx_dw_comm_interleaved(
rank: int,
world_size: int,
@@ -858,21 +776,9 @@ def model_chunk_dx_dw_comm_interleaved(
if idx == 3 or idx == 4:
chunk_3.append(sub_model)
# # test checkpoint
# check_fn = lambda submodule: isinstance(submodule, (Linear))
# non_reentrant_wrapper = partial(
# checkpoint_wrapper,
# # checkpoint_impl=CheckpointImpl.NO_REENTRANT,
# checkpoint_impl=CheckpointImpl.REENTRANT,
# )
# apply_activation_checkpointing(
# model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
# )
print(
f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
)
# set_checkpoint_early_stop(False)
# buffers used to save input and output
##########################
@@ -1051,7 +957,6 @@ def model_chunk_dx_dw_comm_interleaved(
model_chunk=chunk_3,
model_chunk_id=chunk_id,
)
# print(f"input_grad4 {input_grad4}")
######
# bwd rank 1->4
@@ -1069,7 +974,6 @@ def model_chunk_dx_dw_comm_interleaved(
model_chunk=chunk_3,
model_chunk_id=chunk_id,
)
# print(f"input_grad3 {input_grad3}")
# chunk 2 id 0 (layer 2) bwd
if rank == 2:
@@ -1083,7 +987,6 @@ def model_chunk_dx_dw_comm_interleaved(
model_chunk=chunk_2,
model_chunk_id=chunk_id,
)
# print(f"input_grad2 {input_grad2}")
# chunk 1 id 0 (layer 1) bwd
if rank == 1:
@@ -1110,7 +1013,6 @@ def model_chunk_dx_dw_comm_interleaved(
model_chunk=chunk_0,
model_chunk_id=chunk_id,
)
# print(f"input_grad0 {input_grad0}")
##########################
# Fwd bwd for base
@@ -1169,8 +1071,6 @@ def model_chunk_dx_dw_comm_interleaved(
del input2, output2, input_grad2, input5, output5, input_grad5
if rank == 3:
del input3, output3, input_grad3, input4, output4, input_grad4
# print(f"After del device: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
del loss_base, output_base
print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
@@ -1185,11 +1085,4 @@ def test_dx_dw_dist():
if __name__ == "__main__":
# test_dx_dw_split()
# test_double_dx_dw_split_nsync()
# test_double_dx_dw_split_sync()
# mem_dx_dw()
# activation_dx_dw()
# model_chunk_dx_dw()
test_dx_dw_dist()