"""Ad-hoc tests and micro-benchmarks for splitting a backward pass into a
B step (input gradients, dx) and a W step (weight gradients, dw), together
with memory probes for pseudo-deallocating pipeline output tensors."""

import gc
import time
from copy import deepcopy

import torch
import torch.nn as nn
from torch.testing import assert_close


def get_model_numel(model):
    # number of parameters, reported in Mi (2**20) elements
    return sum(p.numel() for p in model.parameters()) / 1024**2


# Step1: dx = w*dy
def backward_b(loss, x, model):
    # B pass: gradients w.r.t. the input only; keep the graph for the later W pass
    torch.autograd.backward(loss, inputs=x, retain_graph=True)


# Step2: dw = x*dy
def backward_w(loss, model):
    # W pass: gradients w.r.t. the parameters only
    torch.autograd.backward(loss, inputs=list(model.parameters()))


def test_double_dx_dw_split_nsync():
    device = "cuda:0"
    model = nn.Linear(4096, 4096, bias=None).to(device=device)
    # print(f"model numel {get_model_numel(model)}")  # 4GB
    x1 = torch.rand(4096, 4096).to(device=device)
    x2 = torch.rand(4096, 4096).to(device=device)
    ref_model = deepcopy(model)
    ref_x1 = x1.clone()
    ref_x2 = x2.clone()

    # first step
    x1.requires_grad_()
    x2.requires_grad_()
    ref_x1.requires_grad_()
    ref_x2.requires_grad_()

    # loss for dx_dw bwd
    loss1 = model(x1).sum()
    loss2 = model(x2).sum()

    # loss for common bwd
    ref_loss1 = ref_model(ref_x1).sum()
    ref_loss2 = ref_model(ref_x2).sum()

    # dx1
    torch.cuda.synchronize()
    bwd_b_start_time = time.time()
    backward_b(loss1, x1, model)
    bwd_b_end_time = time.time()
    print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")

    for p in model.parameters():
        assert p.grad is None
    assert x1.grad is not None

    # dx2
    torch.cuda.synchronize()
    bwd_b_start_time = time.time()
    backward_b(loss2, x2, model)
    bwd_b_end_time = time.time()
    print(f"loss_2 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")

    # dw1
    torch.cuda.synchronize()
    bwd_w_start_time = time.time()
    backward_w(loss1, model)
    bwd_w_end_time = time.time()
    print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")
    for p in model.parameters():
        assert p.grad is not None

    # common bwd 1
    torch.cuda.synchronize()
    comm_bwd_start_time = time.time()
    ref_loss1.backward()
    comm_bwd_end_time = time.time()
    print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")

    # # assert dx1 & dw1 == bwd 1
    # assert_close(x1.grad, ref_x1.grad)
    # for p1, p2 in zip(model.parameters(), ref_model.parameters()):
    #     assert_close(p1, p2)
    #     assert_close(p1.grad, p2.grad)

    # dw2
    torch.cuda.synchronize()
    bwd_w_start_time = time.time()
    backward_w(loss2, model)
    bwd_w_end_time = time.time()
    print(f"loss_2 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")

    # common bwd 2
    torch.cuda.synchronize()
    comm_bwd_start_time = time.time()
    ref_loss2.backward()
    comm_bwd_end_time = time.time()
    print(f"loss_2 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")

    # # assert dx2 & dw2 == bwd 2
    # assert_close(x2.grad, ref_x2.grad)
    # for p1, p2 in zip(model.parameters(), ref_model.parameters()):
    #     print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
    #     assert_close(p1, p2)
    #     assert_close(p1.grad, p2.grad)


def test_double_dx_dw_split_sync():
    device = "cuda:0"
    model = nn.Linear(8, 8, bias=None).to(device=device)
    print(f"model size {get_model_numel(model)}")
    x1 = torch.rand(8, 8).to(device=device)
    x2 = torch.rand(8, 8).to(device=device)
    # x1 = torch.ones(8, 8).to(device=device)
    # x2 = torch.ones(8, 8).to(device=device)
    ref_model = deepcopy(model)
    ref_x1 = x1.clone()
    ref_x2 = x2.clone()

    x1.requires_grad_()
    x2.requires_grad_()
    ref_x1.requires_grad_()
    ref_x2.requires_grad_()

    ############
    # step1:
    ############

    # loss1
    loss1 = model(x1).sum()

    # ref_loss1
    ref_loss1 = ref_model(ref_x1).sum()

    # dx1
    backward_b(loss1, x1, model)
    for p in model.parameters():
        assert p.grad is None
    assert x1.grad is not None

    # dw1
    backward_w(loss1, model)
    for p in model.parameters():
        assert p.grad is not None

    # common bwd 1
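    # Note: loss.backward() on the reference model computes dx and dw in a
    # single pass; the checks below compare it against the split B/W results.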
    ref_loss1.backward()

    # assert dx1 & dw1 == bwd 1
    assert_close(x1.grad, ref_x1.grad)
    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)

    ############
    # step2:
    ############

    # loss2
    loss2 = model(x2).sum()

    # ref_loss2
    ref_loss2 = ref_model(ref_x2).sum()

    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)

    # dx2
    backward_b(loss2, x2, model)

    # dw2
    backward_w(loss2, model)

    # common bwd 2
    ref_loss2.backward()

    # assert dx2 & dw2 == bwd 2
    assert_close(x2.grad, ref_x2.grad)
    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)


def deallocate_output_tensor(out):
    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.

    This method should be called right after the output tensor has been
    sent to the next pipeline stage. At this point, the output tensor is
    only useful for its '.grad_fn' field, and not its '.data'.
    """
    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
    assert out._base is None, "counter-productive to free a view of another tensor."
    out.data = torch.empty(
        (1,),
        device=out.device,
        dtype=out.dtype,
    )


IN_DIM = 8192
OUT_DIM = 8192
NUM_LAYER = 3


class MlpModel(nn.Module):
    # defaults keep the original behaviour; the keyword arguments are needed by
    # model_chunk_dx_dw, which builds a 4-layer 4096-dim variant
    def __init__(self, in_dim=IN_DIM, out_dim=OUT_DIM, num_layers=NUM_LAYER):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, with_qkv=True):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.with_qkv = with_qkv
        if self.with_qkv:
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
            self.proj = nn.Linear(dim, dim)
            self.proj_drop = nn.Dropout(proj_drop)
        self.attn_drop = nn.Dropout(attn_drop)

    def forward(self, x):
        B, N, C = x.shape
        if self.with_qkv:
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]
        else:
            qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
            q, k, v = qkv, qkv, qkv

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        if self.with_qkv:
            x = self.proj(x)
            x = self.proj_drop(x)
        return x


def mem_dx_dw():
    device = "cuda:0"
    # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    model = MlpModel().to(device=device)
    print(f"model numel {get_model_numel(model)}")
    print(f"After init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    print(f"Before init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
    x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
    x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device)

    x1.requires_grad_()
    x2.requires_grad_()
    x3.requires_grad_()
    print(f"After init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ############
    # step1:
    ############
    print(f"\nStep1")

    # loss1
    print(f"Before Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    y1 = model(x1)
    print(f"After Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
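
    # y1 is kept alive because its .grad_fn is needed by both backward_b and
    # backward_w; only after the W pass is its .data released via
    # deallocate_output_tensor below.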
    print(f"Before loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    loss1 = y1.sum()
    print(f"After loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    # dx1
    backward_b(loss1, x1, model)

    # dw1
    backward_w(loss1, model)

    deallocate_output_tensor(x1)
    deallocate_output_tensor(y1)
    # del x1
    # del y1
    print(f"After del x1&y1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    # print(f"\n Step1:collect:{gc.collect()}")
    # print(f"object: {gc.get_objects()}")
    # print(f"garbage: {gc.garbage}")

    ############
    # step2:
    ############
    print(f"\nStep2")

    # loss2
    y2 = model(x2)
    loss2 = y2.sum()

    # dx2
    backward_b(loss2, x2, model)

    # dw2
    backward_w(loss2, model)

    deallocate_output_tensor(x2)
    deallocate_output_tensor(y2)
    # del x2
    # del y2
    print(f"After del x2&y2: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    print(f"\n Step2:collect:{gc.collect()}")
    # print(f"object: {gc.get_objects()}")
    print(f"garbage: {gc.garbage}")

    ############
    # step3:
    ############
    print(f"\nStep3")

    # loss3
    y3 = model(x3)
    loss3 = y3.sum()

    # dx3
    backward_b(loss3, x3, model)

    # dw3
    backward_w(loss3, model)

    deallocate_output_tensor(x3)
    deallocate_output_tensor(y3)
    # del x3
    # del y3
    print(f"After del x3&y3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    print(f"\n Step3:collect:{gc.collect()}")
    # print(f"object: {gc.get_objects()}")
    print(f"garbage: {gc.garbage}")


# del activation
def activation_dx_dw():
    device = "cuda:0"
    # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    model = MlpModel().to(device=device)
    x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
    x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
    x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device)

    x1.requires_grad_()
    x2.requires_grad_()
    x3.requires_grad_()
    print(f"After init Model, x1,x2,x3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    activations = {}

    def register_hooks(module):
        def activation_hook(module, input, output):
            # cache each module's output during the forward pass
            activations[f"{module.__class__.__name__}_{id(module)}"] = output.detach()

        def bwd_hook(module, grad_input, grad_output):
            # drop the cached output once the module's backward has run
            del activations[f"{module.__class__.__name__}_{id(module)}"]

        module.register_forward_hook(activation_hook)
        module.register_backward_hook(bwd_hook)

    model.apply(register_hooks)

    ############
    # step1:
    ############
    print(f"\nStep1")

    # loss1
    loss1 = model(x1).sum()

    # dx1
    backward_b(loss1, x1, model)

    # dw1
    backward_w(loss1, model)

    del loss1, x1
    print(f"After del x1&y1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ############
    # step2:
    ############
    print(f"\nStep2")

    # loss2
    loss2 = model(x2).sum()

    # dx2
    backward_b(loss2, x2, model)

    # dw2
    backward_w(loss2, model)

    # deallocate_output_tensor(x2)
    # deallocate_output_tensor(loss2)
    del x2, loss2
    print(f"After del x2&y2: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ############
    # step3:
    ############
    print(f"\nStep3")

    # loss3
    loss3 = model(x3).sum()

    # dx3
    backward_b(loss3, x3, model)

    # dw3
    backward_w(loss3, model)

    del x3, loss3
    print(f"After del x3&y3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
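

# model_chunk_dx_dw below mimics an interleaved pipeline layout: a 4-layer MLP
# is split into two chunks (layers 0-1 and layers 2-3). The forward runs chunk 0
# and then chunk 1; the backward walks chunk 1 in reverse, running the B pass
# (dx) and then the W pass (dw) for each layer.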
# test dx dw in model chunk
def model_chunk_dx_dw():
    device = "cuda:0"
    num_layers = 4
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    model = MlpModel(in_dim=4096, out_dim=4096, num_layers=num_layers).to(device=device)
    x = torch.rand(4096, 4096).to(device=device)
    x.requires_grad_()

    model_chunk_0 = torch.nn.ModuleList()  # for layer 1 & 2
    model_chunk_1 = torch.nn.ModuleList()  # for layer 3 & 4

    for idx, sub_model in enumerate(model.layers):
        if idx < 2:
            model_chunk_0.append(sub_model)
        else:
            model_chunk_1.append(sub_model)

    print(f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    # Step1: chunk 0 fwd
    activation = dict()  # layer_id: activation
    out = x
    for i in range(len(model_chunk_0)):
        layer = model_chunk_0[i]
        out = layer(out)
        activation[i] = out
    print(f"After chunk0 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    # Step2: chunk 1 fwd
    for i in range(len(model_chunk_1)):
        layer = model_chunk_1[i]
        out = layer(out)
        activation[i + 2] = out
    print(f"After chunk1 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    # Step3: chunk 1 bwd b: dx = w*dy & bwd w: dw = x*dy
    # visit layers in reverse
    for i in range(len(model_chunk_1) - 1, -1, -1):
        layer = model_chunk_1[i]
        global_layer_idx = i + 2
        prev_global_layer_idx = i + 1 if i + 1 > 0 else None

        # bwd b
        if global_layer_idx == num_layers - 1:
            # last layer in last chunk; calculate loss
            loss = activation[global_layer_idx].sum()
            x = activation[prev_global_layer_idx]
            backward_b(loss, x, layer)
        else:
            loss = activation[global_layer_idx].sum()
            x = activation[prev_global_layer_idx]
            backward_b(loss, x, layer)

        # bwd w
        backward_w(loss, layer)


def test_dx_dw_linear_benchmark():
    device = "cuda:0"
    model = nn.Linear(4096, 4096, bias=None).to(device=device)
    # print(f"model numel {get_model_numel(model)}")  # 4GB
    x1 = torch.rand(4096, 4096).to(device=device)
    # x2 = torch.rand(4096, 4096).to(device=device)
    ref_model = deepcopy(model)
    ref_x1 = x1.clone()
    # ref_x2 = x2.clone()

    # first step
    x1.requires_grad_()
    # x2.requires_grad_()
    ref_x1.requires_grad_()
    # ref_x2.requires_grad_()

    # loss for dx_dw bwd
    loss1 = model(x1).sum()
    # loss2 = model(x2).sum()

    # loss for common bwd
    ref_model(ref_x1).sum()
    # ref_loss2 = ref_model(ref_x2).sum()

    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            f"/home/nvme-share/home/duanjunwen/ColossalAI/tests/test_pipeline/test_schedule"
        ),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        with_flops=True,
    ) as prof:
        # dx1
        torch.cuda.synchronize()
        bwd_b_start_time = time.time()
        backward_b(loss1, x1, model)
        bwd_b_end_time = time.time()
        print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")

        for p in model.parameters():
            assert p.grad is None
        assert x1.grad is not None

        # dw1
        torch.cuda.synchronize()
        bwd_w_start_time = time.time()
        backward_w(loss1, model)
        bwd_w_end_time = time.time()
        print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")
        for p in model.parameters():
            assert p.grad is not None

        # # common bwd 1
        # torch.cuda.synchronize()
        # comm_bwd_start_time = time.time()
        # ref_loss1.backward()
        # comm_bwd_end_time = time.time()
        # print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")


def test_dx_dw_attn_benchmark():
    device = "cuda:0"
    model = Attention(dim=4096).to(device=device)
    # print(f"model numel {get_model_numel(model)}")  # 4GB
    x1 = torch.rand(1, 256, 4096).to(device=device)
    # x2 = torch.rand(1, 256, 4096).to(device=device)
    ref_model = deepcopy(model)
    ref_x1 = x1.clone()
    # ref_x2 = x2.clone()

    # first step
    x1.requires_grad_()
    # x2.requires_grad_()
    ref_x1.requires_grad_()
    # ref_x2.requires_grad_()

    # loss for dx_dw bwd
    loss1 = model(x1).sum()
    # loss2 = model(x2).sum()

    # loss for common bwd
    ref_model(ref_x1).sum()
    # ref_loss2 = ref_model(ref_x2).sum()
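
    # Profile the split B/W backward of the attention block (CPU + CUDA activity,
    # memory, stacks, FLOPs); the trace is written for TensorBoard via the
    # hard-coded path below.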
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            f"/home/nvme-share/home/duanjunwen/ColossalAI/tests/test_pipeline/test_schedule"
        ),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        with_flops=True,
    ) as prof:
        # dx1
        torch.cuda.synchronize()
        bwd_b_start_time = time.time()
        backward_b(loss1, x1, model)
        bwd_b_end_time = time.time()
        print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")

        for p in model.parameters():
            assert p.grad is None
        assert x1.grad is not None

        # dw1
        torch.cuda.synchronize()
        bwd_w_start_time = time.time()
        backward_w(loss1, model)
        bwd_w_end_time = time.time()
        print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")
        for p in model.parameters():
            assert p.grad is not None

        # # common bwd 1
        # torch.cuda.synchronize()
        # comm_bwd_start_time = time.time()
        # ref_loss1.backward()
        # comm_bwd_end_time = time.time()
        # print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")


if __name__ == "__main__":
    # test_dx_dw_split()
    # test_double_dx_dw_split_nsync()
    # test_double_dx_dw_split_sync()
    # mem_dx_dw()
    # activation_dx_dw()
    # test_dx_dw_linear_benchmark()
    test_dx_dw_attn_benchmark()