mirror of https://github.com/hpcaitech/ColossalAI
[feat] fix poc format
parent d6e3d7d2a3
commit b5f7b4d228
@@ -1,6 +1,5 @@
import gc
from copy import deepcopy
from typing import Tuple

import torch
import torch.distributed as dist
@@ -13,11 +12,13 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn

# info of model
IN_DIM = 8192
OUT_DIM = 8192
NUM_LAYER = 3


# A simple MLP
class MlpModel(nn.Module):
    def __init__(self, in_dim=IN_DIM, out_dim=OUT_DIM, num_layers=NUM_LAYER):
        super().__init__()
@@ -29,29 +30,10 @@ class MlpModel(nn.Module):
        return x


def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
    num_params = 0
    num_params_trainable = 0
    for p in model.parameters():
        num_params += p.numel()
        if p.requires_grad:
            num_params_trainable += p.numel()
    return num_params, num_params_trainable


# Step 1: dx = w * dy
def backward_b(loss, x, model):
    print(f"Before bwd b: {torch.cuda.memory_allocated()/1024**3 :.3f} GB")
    # print(f"Before x grad {x.grad}")
    # for name, param in model.named_parameters():
    #     print(f"Before bwd b \n param {param}\n param grad {param.grad}\n")

    torch.autograd.backward(loss, inputs=x, retain_graph=True)

    # for name, param in model.named_parameters():
    #     print(f"After bwd b \n param {param}\n param grad {param.grad}\n")

    # print(f"After x grad {x.grad}")
    print(f"After bwd b: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
@@ -64,15 +46,7 @@ def backward_b_not_last(tensors, grad, x, model):

def backward_w(loss, model):
    print(f"Before bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    # for name, param in model.named_parameters():
    #     print(f"Before bwd w \n param {param}\n param grad {param.grad}\n")

    torch.autograd.backward(loss, inputs=list(model.parameters()))

    # for name, param in model.named_parameters():
    #     print(f"After bwd w \n param {param}\n param grad {param.grad}\n")

    print(f"After bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
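
# A minimal usage sketch for the two split helpers above (assumes a CUDA device is
# available; not part of the diff itself): backward_b computes dx (input gradients)
# while retaining the graph, and backward_w then computes dw (parameter gradients).
def _example_split_backward():
    model = MlpModel().to(device="cuda:0")
    x = torch.rand(8, IN_DIM, device="cuda:0", requires_grad=True)
    loss = model(x).sum()
    backward_b(loss, x, model)  # dx: populates x.grad, keeps the graph alive
    backward_w(loss, model)     # dw: populates p.grad for every parameter
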
@@ -83,6 +57,7 @@ def backward_w_not_last(tensors, grad, model):
    print(f"After bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")


# In this poc, we check the feasibility of splitting dx and dw in bwd propagation
def test_dx_dw_split():
    device = "cuda:0"
    model = nn.Linear(8, 8, bias=None).to(device=device)
@@ -116,6 +91,8 @@ def test_dx_dw_split():
        assert torch.equal(p1.grad, p2.grad)


# In this poc, we check splitting dx and dw in bwd propagation without syncing, in the following order:
# fwd1 --> fwd2 --> dx1 --> dx2 --> dw1 --> dw2
def test_double_dx_dw_split_nsync():
    device = "cuda:0"
    model = nn.Linear(8, 8, bias=None).to(device=device)
@@ -177,16 +154,14 @@ def test_double_dx_dw_split_nsync():
        assert_close(p1.grad, p2.grad)


# In this poc, we check splitting dx and dw in bwd propagation with syncing, in the following order:
# fwd1 --> fwd2 --> dx1 --> dw1 --> dx2 --> dw2
def test_double_dx_dw_split_sync():
    device = "cuda:0"
    model = nn.Linear(8, 8, bias=None).to(device=device)
    # print(f"model numel {get_model_numel(model)}")  # 4GB
    x1 = torch.rand(8, 8).to(device=device)
    x2 = torch.rand(8, 8).to(device=device)

    # x1 = torch.ones(8, 8).to(device=device)
    # x2 = torch.ones(8, 8).to(device=device)

    ref_model = deepcopy(model)
    ref_x1 = x1.clone()
    ref_x2 = x2.clone()
@@ -239,7 +214,6 @@ def test_double_dx_dw_split_sync():
    ref_loss2 = ref_model(ref_x2).sum()

    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        # print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)
@@ -255,31 +229,13 @@ def test_double_dx_dw_split_sync():
    # assert dx2 & dw2 == bwd 2
    assert_close(x2.grad, ref_x2.grad)
    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        # print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)
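
# A minimal sketch of the "sync" ordering exercised by test_double_dx_dw_split_sync
# above (fwd1 --> fwd2 --> dx1 --> dw1 --> dx2 --> dw2), written with the split
# helpers; the model and both inputs are assumed to live on the same CUDA device:
def _example_sync_order(model, x1, x2):
    x1.requires_grad_()
    x2.requires_grad_()
    loss1 = model(x1).sum()       # fwd1
    loss2 = model(x2).sum()       # fwd2
    backward_b(loss1, x1, model)  # dx1
    backward_w(loss1, model)      # dw1
    backward_b(loss2, x2, model)  # dx2
    backward_w(loss2, model)      # dw2
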


def deallocate_output_tensor(out):
    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.

    This method should be called right after the output tensor has been
    sent to the next pipeline stage. At this point, the output tensor is
    only useful for its '.grad_fn' field, and not its '.data'.
    """
    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
    assert out._base is None, "counter-productive to free a view of another tensor."
    out.data = torch.empty(
        (1,),
        device=out.device,
        dtype=out.dtype,
    )
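

# A minimal usage sketch for deallocate_output_tensor (assumed API: the
# PipelineP2PCommunication send_forward used elsewhere in this file). Once the
# activation has been shipped to the next stage, only its grad_fn is still needed,
# so the storage can be released while the tensor object stays alive for backward.
def _example_send_and_free(model_chunk, comm, x, next_rank):
    output = model_chunk(x)
    comm.send_forward(output, next_rank)
    deallocate_output_tensor(output)  # keep grad_fn, replace .data with a 1-element stub
    return output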


# del loss and x
# In this poc, we check whether a memory leak occurs after deleting the input & loss (with graph)
def mem_dx_dw():
    device = "cuda:0"
    # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    model = MlpModel().to(device=device)
    print(f"model numel {get_model_numel(model)}")  # 4GB
@@ -314,8 +270,6 @@ def mem_dx_dw():
    # dw1
    backward_w(loss1, model)

    # deallocate_output_tensor(x1)
    # deallocate_output_tensor(loss1)
    del loss1, x1
    # del x1
    # del y1
@@ -335,8 +289,6 @@ def mem_dx_dw():
    # dw2
    backward_w(loss2, model)

    # deallocate_output_tensor(x2)
    # deallocate_output_tensor(loss2)
    del x2, loss2
    # del x2
    # del y2
@@ -356,8 +308,6 @@ def mem_dx_dw():
    # dw3
    backward_w(loss3, model)

    # deallocate_output_tensor(x3)
    # deallocate_output_tensor(loss3)
    # del x3
    # del y3
    del x3, loss3
@@ -370,7 +320,7 @@ def mem_dx_dw():
            print(obj)


# del activation
# In this poc, we check whether a memory leak occurs after deleting the input & loss (with graph) & activation
def activation_dx_dw():
    device = "cuda:0"
    # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
@@ -385,17 +335,6 @@ def activation_dx_dw():
    x3.requires_grad_()
    print(f"After init Model, x1,x2,x3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    # activations = {}
    # def register_hooks(module):
    #     def activation_hook(module, input, output):
    #         activations[f"{module.__class__.__name__}_{id(module)}"] = output.detach()
    #     def bwd_hook(module, grad_input, grad_output):
    #         del activations[f"{module.__class__.__name__}_{id(module)}"]
    #     module.register_forward_hook(activation_hook)
    #     module.register_backward_hook(bwd_hook)

    # model.apply(register_hooks)

    ############
    # step1:
    ############
@@ -408,15 +347,9 @@ def activation_dx_dw():
    # dx1
    backward_b(loss1, x1, model)

    # for name, p in model.named_parameters():
    #     print(f"p grad {p.grad}")

    # dw1
    backward_w(loss1, model)

    # for name, p in model.named_parameters():
    #     del p.grad

    # del loss1, x1
    del loss1, x1, output1
    print(f"After del : {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
@@ -433,15 +366,9 @@ def activation_dx_dw():
    # dx2
    backward_b(loss2, x2, model)

    # for name, p in model.named_parameters():
    #     print(f"p grad {p.grad}")

    # dw2
    backward_w(loss2, model)

    # for name, p in model.named_parameters():
    #     print(f"p grad {p.grad}")

    # del x2, loss2
    del x2, loss2, output2
    print(f"After del : {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
@@ -467,6 +394,7 @@ def activation_dx_dw():
    print(f"After del : {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")


# In this poc, we split by model chunk instead of by layer
def model_chunk_dx_dw():
    device = "cuda:0"
    num_layers = 4
@@ -555,6 +483,7 @@ def model_chunk_dx_dw():
    print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")


# In this poc, we use model chunks and a pp group for communication
def model_chunk_dx_dw_communication(
    rank: int,
    world_size: int,
@@ -598,9 +527,6 @@ def model_chunk_dx_dw_communication(
    ##########################
    if rank == 0:
        output1 = model_chunk_0(input)
        # detach output1; then output1 for chunk 0, output1_dt for chunk 1;
        # output1_dt_rank0 = output1.detach()
        # output1_dt_rank0.requires_grad_()
        print(
            f"After chunk0 fwd (include detach output1): {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )
@@ -689,7 +615,7 @@ def model_chunk_dx_dw_communication(
    print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")


# Return: output, loss
# fwd schedule
def schedule_f(
    stage_manager: PipelineStageManager,
    comm: PipelineP2PCommunication,
@@ -738,6 +664,7 @@ def schedule_f(
    return input, output, None


# bwd b schedule
def schedule_b(
    stage_manager: PipelineStageManager,
    comm: PipelineP2PCommunication,
@@ -759,7 +686,6 @@ def schedule_b(

        # bwd step
        backward_b_not_last(tensors=output, grad=output_grad, x=input, model=model_chunk[model_chunk_id])
        backward_w_not_last(tensors=output, grad=output_grad, model=model_chunk[model_chunk_id])

        # send bwd to prev
@@ -776,27 +702,17 @@ def schedule_b(
        output_grad = output_grad
    else:
        prev_rank = stage_manager.get_prev_rank()
        # print(f"prev_rank {prev_rank} curr rank {stage_manager.get_rank()}")
        output_grad, _ = comm.recv_backward(next_rank=prev_rank)

    # bwd step
    # print(f"Before input grad {input.grad}")
    # for name, param in model_chunk[model_chunk_id].named_parameters():
    #     print(f"Before {name} grad {param.grad}")

    if stage_manager.is_first_stage(ignore_chunk=True):
        backward_b(loss=output_grad, x=input, model=model_chunk[model_chunk_id])
        backward_w(loss=output_grad, model=model_chunk[model_chunk_id])
    else:
        # common bwd step
        # print(f"output_grad {output_grad}")
        backward_b_not_last(tensors=output, grad=output_grad, x=input, model=model_chunk[model_chunk_id])
        backward_w_not_last(tensors=output, grad=output_grad, model=model_chunk[model_chunk_id])

    # print(f"After input grad {input.grad}")
    # for name, param in model_chunk[model_chunk_id].named_parameters():
    #     print(f"After {name} grad {param.grad}")

    # send bwd to next
    if stage_manager.is_last_stage(ignore_chunk=True):
        return input.grad
@@ -807,10 +723,12 @@ def schedule_b(
    return input.grad


# bwd w schedule (dw is already split out in schedule_b)
def schedule_w():
    pass
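
# A rough sketch of how the three schedules fit together on each rank (argument
# lists abbreviated and hypothetical; see the full definitions above and the
# interleaved driver below for the real signatures):
#
#     input, output, _ = schedule_f(stage_manager, comm, ...)   # forward + send fwd
#     input_grad = schedule_b(stage_manager, comm, ...)         # recv grad, dx + dw, send bwd
#     schedule_w()                                               # dw already handled in schedule_b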


# In this poc, we apply a scheduling method for each rank: schedule_f --> schedule_b --> schedule_w
def model_chunk_dx_dw_comm_interleaved(
    rank: int,
    world_size: int,
@@ -858,21 +776,9 @@ def model_chunk_dx_dw_comm_interleaved(
        if idx == 3 or idx == 4:
            chunk_3.append(sub_model)

    # # test checkpoint
    # check_fn = lambda submodule: isinstance(submodule, (Linear))
    # non_reentrant_wrapper = partial(
    #     checkpoint_wrapper,
    #     # checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    #     checkpoint_impl=CheckpointImpl.REENTRANT,
    # )
    # apply_activation_checkpointing(
    #     model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    # )

    print(
        f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
    )
    # set_checkpoint_early_stop(False)
    # buffer used to save input and output

    ##########################
@@ -1051,7 +957,6 @@ def model_chunk_dx_dw_comm_interleaved(
            model_chunk=chunk_3,
            model_chunk_id=chunk_id,
        )
        # print(f"input_grad4 {input_grad4}")

    ######
    # bwd rank 1->4
@@ -1069,7 +974,6 @@ def model_chunk_dx_dw_comm_interleaved(
            model_chunk=chunk_3,
            model_chunk_id=chunk_id,
        )
        # print(f"input_grad3 {input_grad3}")

    # chunk 2 id 0 (layer 2) bwd
    if rank == 2:
@@ -1083,7 +987,6 @@ def model_chunk_dx_dw_comm_interleaved(
            model_chunk=chunk_2,
            model_chunk_id=chunk_id,
        )
        # print(f"input_grad2 {input_grad2}")

    # chunk 1 id 0 (layer 1) bwd
    if rank == 1:
@@ -1110,7 +1013,6 @@ def model_chunk_dx_dw_comm_interleaved(
            model_chunk=chunk_0,
            model_chunk_id=chunk_id,
        )
        # print(f"input_grad0 {input_grad0}")

    ##########################
    # Fwd bwd for base
@@ -1169,8 +1071,6 @@ def model_chunk_dx_dw_comm_interleaved(
        del input2, output2, input_grad2, input5, output5, input_grad5
    if rank == 3:
        del input3, output3, input_grad3, input4, output4, input_grad4
    # print(f"After del device: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")

    del loss_base, output_base

    print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
@@ -1185,11 +1085,4 @@ def test_dx_dw_dist():


if __name__ == "__main__":
    # test_dx_dw_split()
    # test_double_dx_dw_split_nsync()
    # test_double_dx_dw_split_sync()
    # mem_dx_dw()
    # activation_dx_dw()
    # model_chunk_dx_dw()

    test_dx_dw_dist()