import gc
from copy import deepcopy
from typing import Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn

IN_DIM = 8192
OUT_DIM = 8192
NUM_LAYER = 3


class MlpModel(nn.Module):
    def __init__(self, in_dim=IN_DIM, out_dim=OUT_DIM, num_layers=NUM_LAYER):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]:
    num_params = 0
    num_params_trainable = 0
    for p in model.parameters():
        num_params += p.numel()
        if p.requires_grad:
            num_params_trainable += p.numel()
    return num_params, num_params_trainable


# Step 1: dx = w * dy (input-gradient pass only)
def backward_b(loss, x, model):
    print(f"Before bwd b: {torch.cuda.memory_allocated()/1024**3 :.3f} GB")
    # print(f"Before x grad {x.grad}")
    # for name, param in model.named_parameters():
    #     print(f"Before bwd b \n param {param}\n param grad {param.grad}\n")

    torch.autograd.backward(loss, inputs=x, retain_graph=True)

    # for name, param in model.named_parameters():
    #     print(f"After bwd b \n param {param}\n param grad {param.grad}\n")
    # print(f"After x grad {x.grad}")
    print(f"After bwd b: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")


# Step 1: dx = w * dy; for a stage that is not the last one (dy is provided explicitly)
def backward_b_not_last(tensors, grad, x, model):
    print(f"Before bwd b: {torch.cuda.memory_allocated()/1024**3 :.3f} GB")
    torch.autograd.backward(tensors=tensors, grad_tensors=grad, inputs=x, retain_graph=True)
    print(f"After bwd b: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")


# Step 2: dw = x * dy (weight-gradient pass only)
def backward_w(loss, model):
    print(f"Before bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    # for name, param in model.named_parameters():
    #     print(f"Before bwd w \n param {param}\n param grad {param.grad}\n")

    torch.autograd.backward(loss, inputs=list(model.parameters()))

    # for name, param in model.named_parameters():
    #     print(f"After bwd w \n param {param}\n param grad {param.grad}\n")
    print(f"After bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")


# Step 2: dw = x * dy; for a stage that is not the last one (dy is provided explicitly)
def backward_w_not_last(tensors, grad, model):
    print(f"Before bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    torch.autograd.backward(tensors=tensors, grad_tensors=grad, inputs=list(model.parameters()))
    print(f"After bwd w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
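

# A minimal, CPU-runnable sketch of the same two-phase backward that backward_b and
# backward_w perform above, kept here purely as a reference. `_toy_split_backward` is
# not part of the original tests and is never called by them.
def _toy_split_backward():
    layer = nn.Linear(4, 4, bias=None)
    x = torch.randn(2, 4, requires_grad=True)
    loss = layer(x).sum()
    # B pass: only dx is produced; retain the graph so the W pass can reuse it.
    torch.autograd.backward(loss, inputs=[x], retain_graph=True)
    assert x.grad is not None and layer.weight.grad is None
    # W pass: only dw is produced; the graph may be freed afterwards.
    torch.autograd.backward(loss, inputs=list(layer.parameters()))
    assert layer.weight.grad is not None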


def test_dx_dw_split():
    device = "cuda:0"
    model = nn.Linear(8, 8, bias=None).to(device=device)
    print(f"model numel {get_model_numel(model)}")
    x = torch.rand(8, 8).to(device=device)
    ref_model = deepcopy(model)
    ref_x = x.clone()

    # first step
    x.requires_grad_()
    loss = model(x).sum()
    backward_b(loss, x, model)
    for p in model.parameters():
        assert p.grad is None
    assert x.grad is not None

    backward_w(loss, model)
    for p in model.parameters():
        assert p.grad is not None

    # # second step
    # loss = model(x).sum()
    # backward_b(loss, x, model)
    # backward_w(loss, model)

    ref_x.requires_grad_()
    ref_loss = ref_model(ref_x).sum()
    ref_loss.backward()

    assert torch.equal(x.grad, ref_x.grad)
    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.equal(p1.grad, p2.grad)


def test_double_dx_dw_split_nsync():
    device = "cuda:0"
    model = nn.Linear(8, 8, bias=None).to(device=device)
    # print(f"model numel {get_model_numel(model)}")
    x1 = torch.rand(8, 8).to(device=device)
    x2 = torch.rand(8, 8).to(device=device)
    ref_model = deepcopy(model)
    ref_x1 = x1.clone()
    ref_x2 = x2.clone()

    # first step
    x1.requires_grad_()
    x2.requires_grad_()
    ref_x1.requires_grad_()
    ref_x2.requires_grad_()

    # loss for dx_dw bwd
    loss1 = model(x1).sum()
    loss2 = model(x2).sum()

    # loss for common bwd
    ref_loss1 = ref_model(ref_x1).sum()
    ref_loss2 = ref_model(ref_x2).sum()

    # dx1
    backward_b(loss1, x1, model)
    for p in model.parameters():
        assert p.grad is None
    assert x1.grad is not None

    # dx2
    backward_b(loss2, x2, model)

    # dw1
    backward_w(loss1, model)
    for p in model.parameters():
        assert p.grad is not None

    # common bwd 1
    ref_loss1.backward()

    # assert dx1 & dw1 == bwd 1
    assert_close(x1.grad, ref_x1.grad)
    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)

    # dw2
    backward_w(loss2, model)

    # common bwd 2
    ref_loss2.backward()

    # assert dx2 & dw2 == bwd 2
    assert_close(x2.grad, ref_x2.grad)
    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)


def test_double_dx_dw_split_sync():
    device = "cuda:0"
    model = nn.Linear(8, 8, bias=None).to(device=device)
    # print(f"model numel {get_model_numel(model)}")
    x1 = torch.rand(8, 8).to(device=device)
    x2 = torch.rand(8, 8).to(device=device)
    # x1 = torch.ones(8, 8).to(device=device)
    # x2 = torch.ones(8, 8).to(device=device)
    ref_model = deepcopy(model)
    ref_x1 = x1.clone()
    ref_x2 = x2.clone()

    x1.requires_grad_()
    x2.requires_grad_()
    ref_x1.requires_grad_()
    ref_x2.requires_grad_()

    ############
    # step1:
    ############
    print(f"Step1\n")

    # loss1
    loss1 = model(x1).sum()

    # ref_loss1
    ref_loss1 = ref_model(ref_x1).sum()

    # dx1
    backward_b(loss1, x1, model)
    for p in model.parameters():
        assert p.grad is None
    assert x1.grad is not None

    # dw1
    backward_w(loss1, model)
    for p in model.parameters():
        assert p.grad is not None

    # common bwd 1
    ref_loss1.backward()

    # assert dx1 & dw1 == bwd 1
    assert_close(x1.grad, ref_x1.grad)
    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)

    ############
    # step2:
    ############
    print(f"Step2\n")

    # loss2
    loss2 = model(x2).sum()

    # ref_loss2
    ref_loss2 = ref_model(ref_x2).sum()

    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        # print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)

    # dx2
    backward_b(loss2, x2, model)

    # dw2
    backward_w(loss2, model)

    # common bwd 2
    ref_loss2.backward()

    # assert dx2 & dw2 == bwd 2
    assert_close(x2.grad, ref_x2.grad)
    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        # print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)


def deallocate_output_tensor(out):
    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.

    This method should be called right after the output tensor has been
    sent to the next pipeline stage. At this point, the output tensor is
    only useful for its '.grad_fn' field, and not its '.data'.
    """
    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
    assert out._base is None, "counter-productive to free a view of another tensor."
    out.data = torch.empty(
        (1,),
        device=out.device,
        dtype=out.dtype,
    )
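

# Illustrative usage of deallocate_output_tensor (the calls in the functions below are
# commented out). This is only a sketch of the intended call site: free the activation's
# storage once its value has been shipped to the next stage, while keeping `.grad_fn`
# alive for the later backward. Note that a plain `torch.autograd.backward(out, grad)`
# would then check `grad` against `out`'s (now fake) shape, so the real backward has to
# be driven through tensors whose `.data` was not freed, as the schedules below do.
# `_toy_deallocate_after_send` and its `send_fn` argument are illustrative only.
def _toy_deallocate_after_send(send_fn=None):
    layer = nn.Linear(4, 4, bias=None)
    out = layer(torch.randn(2, 4))
    if send_fn is not None:
        send_fn(out)  # e.g. comm.send_forward(out, next_rank) in the pipeline case
    deallocate_output_tensor(out)
    # storage is reduced to a single element, but the autograd node is still reachable
    assert out.data.numel() == 1 and out.grad_fn is not None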


# del loss and x
def mem_dx_dw():
    device = "cuda:0"
    # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    model = MlpModel().to(device=device)
    print(f"model numel {get_model_numel(model)}")  # 4GB
    print(f"After init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    print(f"Before init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
    x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
    x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device)

    x1.requires_grad_()
    x2.requires_grad_()
    x3.requires_grad_()
    print(f"After init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ############
    # step1:
    ############
    print(f"\nStep1")

    # loss1
    print(f"Before Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    loss1 = model(x1).sum()
    print(f"After Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    print(f"Before loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    print(f"After loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    # dx1
    backward_b(loss1, x1, model)

    # dw1
    backward_w(loss1, model)

    # deallocate_output_tensor(x1)
    # deallocate_output_tensor(loss1)
    del loss1, x1
    # del x1
    # del y1
    print(f"After del x1&y1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ############
    # step2:
    ############
    print(f"\nStep2")

    # loss2
    loss2 = model(x2).sum()

    # dx2
    backward_b(loss2, x2, model)

    # dw2
    backward_w(loss2, model)

    # deallocate_output_tensor(x2)
    # deallocate_output_tensor(loss2)
    del x2, loss2
    # del x2
    # del y2
    print(f"After del x2&y2: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ############
    # step3:
    ############
    print(f"\nStep3")

    # loss3
    loss3 = model(x3).sum()

    # dx3
    backward_b(loss3, x3, model)

    # dw3
    backward_w(loss3, model)

    # deallocate_output_tensor(x3)
    # deallocate_output_tensor(loss3)
    # del x3
    # del y3
    del x3, loss3
    print(f"After del x3&y3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    param_ids = [id(p) for p in model.parameters()]
    for obj in gc.get_objects():
        if torch.is_tensor(obj) and id(obj) not in param_ids:
            print(obj)


# del activation
def activation_dx_dw():
    device = "cuda:0"
    # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    model = MlpModel().to(device=device)
    x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
    x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
    x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device)

    x1.requires_grad_()
    x2.requires_grad_()
    x3.requires_grad_()
    print(f"After init Model, x1,x2,x3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    # activations = {}
    # def register_hooks(module):
    #     def activation_hook(module, input, output):
    #         activations[f"{module.__class__.__name__}_{id(module)}"] = output.detach()
    #     def bwd_hook(module, grad_input, grad_output):
    #         del activations[f"{module.__class__.__name__}_{id(module)}"]
    #     module.register_forward_hook(activation_hook)
    #     module.register_backward_hook(bwd_hook)
    # model.apply(register_hooks)

    ############
    # step1:
    ############
    print(f"\nStep1")

    # loss1
    output1 = model(x1)
    loss1 = output1.sum()

    # dx1
    backward_b(loss1, x1, model)
    # for name, p in model.named_parameters():
    #     print(f"p grad {p.grad}")

    # dw1
    backward_w(loss1, model)
    # for name, p in model.named_parameters():
    #     del p.grad

    # del loss1, x1
    del loss1, x1, output1
    print(f"After del : {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ############
    # step2:
    ############
    print(f"\nStep2")

    # loss2
    output2 = model(x2)
    loss2 = output2.sum()

    # dx2
    backward_b(loss2, x2, model)
    # for name, p in model.named_parameters():
    #     print(f"p grad {p.grad}")

    # dw2
    backward_w(loss2, model)
    # for name, p in model.named_parameters():
    #     print(f"p grad {p.grad}")

    # del x2, loss2
    del x2, loss2, output2
    print(f"After del : {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ############
    # step3:
    ############
    print(f"\nStep3")

    # loss3
    output3 = model(x3)
    loss3 = output3.sum()

    # dx3
    backward_b(loss3, x3, model)

    # dw3
    backward_w(loss3, model)

    # del x3, loss3
    del x3, loss3, output3
    print(f"After del : {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")


def model_chunk_dx_dw():
    device = "cuda:0"
    num_layers = 4
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    model = MlpModel(in_dim=4096, out_dim=4096, num_layers=num_layers).to(device=device)
    input = torch.rand(4096, 4096, requires_grad=True).to(device=device)
    input_base = input.clone()
    model_base = deepcopy(model)

    ##########################
    # Fwd bwd for dx dw
    ##########################
    model_chunk_0 = torch.nn.Sequential()  # for layer 1 & 2
    model_chunk_1 = torch.nn.Sequential()  # for layer 3 & 4

    for idx, sub_model in enumerate(model.layers):
        if idx < 2:
            model_chunk_0.append(sub_model)
        else:
            model_chunk_1.append(sub_model)
    print(f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # Step1: chunk 0 fwd
    ##########################
    output1 = model_chunk_0(input)
    # detach output1; then output1 for chunk 0, output1_dt for chunk 1;
    output1_dt = output1.detach()
    output1_dt.requires_grad_()
    print(f"After chunk0 fwd (include detach output1): {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # Step2: chunk 1 fwd
    ##########################
    output2 = model_chunk_1(output1_dt)
    print(f"After chunk1 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # Step3: chunk 1 bwd b: dx = w*dy & bwd w: dw = x*dy
    ##########################
    loss = output2.mean()

    backward_b(loss, output1_dt, model_chunk_1)
    backward_w(loss, model_chunk_1)
    print(f"After chunk1 bwd b & w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # Step4: chunk 0 bwd b: dx = w*dy & bwd w: dw = x*dy
    ##########################
    # dx = w*dy
    backward_b_not_last(tensors=output1, grad=output1_dt.grad, x=input, model=model_chunk_0)
    backward_w_not_last(tensors=output1, grad=output1_dt.grad, model=model_chunk_0)
    print(f"After chunk0 bwd b & w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # Fwd bwd for base
    ##########################
    # fwd & bwd
    output_base = model_base(input_base)
    loss_base = output_base.mean()
    loss_base.backward()
    print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # Assert param
    ##########################
    assert_close(output2, output_base)
    assert_close(output2.grad, output_base.grad)
    for p1, p2 in zip(model.parameters(), model_base.parameters()):
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)

    del output1, output1_dt, output2, loss, loss_base, output_base
    print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
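

# The next function replays model_chunk_dx_dw() across two ranks: the first two layers
# form model_chunk_0 on rank 0 and the last two layers form model_chunk_1 on rank 1.
# The boundary activation travels rank 0 -> rank 1 via send_forward/recv_forward, and
# its gradient travels back rank 1 -> rank 0 via send_backward/recv_backward, with the
# received activation playing the role of the detached `output1_dt` above.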
def model_chunk_dx_dw_communication(
    rank: int,
    world_size: int,
    port: int,
):
    # init dist
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
    pg_mesh = ProcessGroupMesh(world_size)
    stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=2)
    rank = dist.get_rank()
    comm = PipelineP2PCommunication(stage_manager, overlap_p2p=False)

    print(f"{stage_manager.get_rank()}")

    # init model and input
    num_layers = 4
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
    model = MlpModel(in_dim=4096, out_dim=4096, num_layers=num_layers).to(rank)
    input = torch.rand(4096, 4096, requires_grad=True).to(rank)
    input_base = input.clone()
    model_base = deepcopy(model)

    if rank == 0:
        model_chunk_0 = torch.nn.Sequential().to(rank)  # for layer 1 & 2 on rank0
        for idx, sub_model in enumerate(model.layers):
            if idx < 2:
                model_chunk_0.append(sub_model)
    else:
        model_chunk_1 = torch.nn.Sequential().to(rank)  # for layer 3 & 4 on rank1
        for idx, sub_model in enumerate(model.layers):
            if idx >= 2:
                model_chunk_1.append(sub_model)

    print(
        f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
    )

    ##########################
    # Step1: chunk 0 fwd
    ##########################
    if rank == 0:
        output1 = model_chunk_0(input)
        # detach output1; then output1 for chunk 0, output1_dt for chunk 1;
        # output1_dt_rank0 = output1.detach()
        # output1_dt_rank0.requires_grad_()
        print(
            f"After chunk0 fwd (include detach output1): {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

        # send y(output1_dt) to next stage
        comm.send_forward(output1, stage_manager.get_next_rank())

    ##########################
    # Step2: chunk 1 fwd
    ##########################
    if rank == 1:
        # recv y(output1_dt) from prev stage
        output1_dt_rank1, wait_handles = comm.recv_forward(stage_manager.get_prev_rank())
        output1_dt_rank1.requires_grad_()
        output2 = model_chunk_1(output1_dt_rank1)
        print(
            f"After chunk1 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

    ##########################
    # Step3: chunk 1 on device_1 bwd b: dx = w*dy & bwd w: dw = x*dy
    ##########################
    if rank == 1:
        loss = output2.mean()

        backward_b(loss, output1_dt_rank1, model_chunk_1)
        backward_w(loss, model_chunk_1)
        print(f"After chunk1 bwd b & w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

        # send bwd output1_dt_rank1 from rank1 to rank 0
        comm.send_backward(output1_dt_rank1.grad, stage_manager.get_prev_rank())

    ##########################
    # Step4: chunk 0 on device_0 bwd b: dx = w*dy & bwd w: dw = x*dy
    ##########################
    if rank == 0:
        # recv bwd output1_dt_rank1 from rank1 to rank 0
        output1_dt_rank0_grad, _ = comm.recv_backward(stage_manager.get_next_rank())

        backward_b_not_last(tensors=output1, grad=output1_dt_rank0_grad, x=input, model=model_chunk_0)
        backward_w_not_last(tensors=output1, grad=output1_dt_rank0_grad, model=model_chunk_0)
        print(f"After chunk0 bwd b & w: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # Fwd bwd for base
    ##########################
    # fwd & bwd
    output_base = model_base(input_base)
    loss_base = output_base.mean()
    loss_base.backward()
    print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # Assert param
    ##########################
    # assert output
    if rank == 1:
        assert_close(output2, output_base)
        assert_close(output2.grad, output_base.grad)

    # assert model param & grad
    if rank == 0:
        count = 0
        for (chunk_name, chunk_param), (base_name, base_param) in zip(
            model_chunk_0.named_parameters(), model_base.named_parameters()
        ):
            if count < 2:
                assert_close(chunk_param, base_param)
                assert_close(chunk_param.grad, base_param.grad)
            count += 1

    if rank == 1:
        count = 0
        for (chunk_name, chunk_param), (base_name, base_param) in zip(
            model_chunk_1.named_parameters(), model_base.named_parameters()
        ):
            if count >= 2:
                assert_close(chunk_param, base_param)
                assert_close(chunk_param.grad, base_param.grad)
            count += 1

    # clean memory
    if rank == 0:
        del output1, output1_dt_rank0_grad
    if rank == 1:
        del output2, loss, output1_dt_rank1
    del loss_base, output_base
    print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
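

# Notes on the two schedule helpers below (descriptive only):
#  * schedule_f runs the forward for one model chunk. For chunk 0 the data flows
#    first stage -> last stage, so a non-first stage receives its input from the
#    previous rank and a non-last stage sends its output to the next rank. For
#    chunk 1 the flow is reversed (last stage -> first stage), and the first stage
#    computes the loss locally instead of sending the output on.
#  * schedule_b mirrors this for the backward: dy is received from the rank the
#    output was sent to, backward_b_* / backward_w_* split the backward into a dx
#    pass and a dw pass, and dx is sent back along the reverse of the forward path.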
# Return: output, loss
def schedule_f(
    stage_manager: PipelineStageManager,
    comm: PipelineP2PCommunication,
    input: torch.Tensor,
    model_chunk: torch.nn.ModuleList,
    model_chunk_id: int,
):
    # chunk_id == 0
    if model_chunk_id == 0:
        # recv fwd from prev
        if stage_manager.is_first_stage(ignore_chunk=True):
            input = input  # get local input
        else:
            prev_rank = stage_manager.get_prev_rank()
            input, wait_handles = comm.recv_forward(prev_rank)

        # fwd step
        output = model_chunk[model_chunk_id](input)

        # send fwd to next
        if stage_manager.is_last_stage(ignore_chunk=True):
            return input, output, None  # return local output
        else:
            next_rank = stage_manager.get_next_rank()
            comm.send_forward(output, next_rank)

    # chunk_id == 1
    if model_chunk_id == 1:
        # recv fwd from next
        if stage_manager.is_last_stage(ignore_chunk=True):
            input = input  # get local input
        else:
            next_rank = stage_manager.get_next_rank()
            input, wait_handles = comm.recv_forward(next_rank)

        # fwd step
        output = model_chunk[model_chunk_id](input)

        # send fwd to prev
        if stage_manager.is_first_stage(ignore_chunk=True):
            loss = output.mean()
            return input, output, loss  # return local output
        else:
            prev_rank = stage_manager.get_prev_rank()
            comm.send_forward(output, prev_rank)
    return input, output, None


def schedule_b(
    stage_manager: PipelineStageManager,
    comm: PipelineP2PCommunication,
    input: torch.Tensor,  # x
    output: torch.Tensor,  # y
    output_grad: torch.Tensor,  # dy
    model_chunk: torch.nn.ModuleList,
    model_chunk_id: int,
):
    # chunk_id == 0
    if model_chunk_id == 0:
        # recv bwd from next
        if stage_manager.is_last_stage(ignore_chunk=True):
            output_grad = output_grad  # get dy from local
        else:
            next_rank = stage_manager.get_next_rank()
            output_grad, _ = comm.recv_backward(next_rank)

        # bwd step
        backward_b_not_last(tensors=output, grad=output_grad, x=input, model=model_chunk[model_chunk_id])
        backward_w_not_last(tensors=output, grad=output_grad, model=model_chunk[model_chunk_id])

        # send bwd to prev
        if stage_manager.is_first_stage(ignore_chunk=True):
            return input.grad
        else:
            prev_rank = stage_manager.get_prev_rank()
            comm.send_backward(input.grad, prev_rank)

    # chunk_id == 1
    if model_chunk_id == 1:
        # recv bwd from prev
        if stage_manager.is_first_stage(ignore_chunk=True):
            output_grad = output_grad
        else:
            prev_rank = stage_manager.get_prev_rank()
            # print(f"prev_rank {prev_rank} curr rank {stage_manager.get_rank()}")
            output_grad, _ = comm.recv_backward(next_rank=prev_rank)

        # bwd step
        # print(f"Before input grad {input.grad}")
        # for name, param in model_chunk[model_chunk_id].named_parameters():
        #     print(f"Before {name} grad {param.grad}")
        if stage_manager.is_first_stage(ignore_chunk=True):
            backward_b(loss=output_grad, x=input, model=model_chunk[model_chunk_id])
            backward_w(loss=output_grad, model=model_chunk[model_chunk_id])
        else:
            # common bwd step
            # print(f"output_grad {output_grad}")
            backward_b_not_last(tensors=output, grad=output_grad, x=input, model=model_chunk[model_chunk_id])
            backward_w_not_last(tensors=output, grad=output_grad, model=model_chunk[model_chunk_id])
        # print(f"After input grad {input.grad}")
        # for name, param in model_chunk[model_chunk_id].named_parameters():
        #     print(f"After {name} grad {param.grad}")

        # send bwd to next
        if stage_manager.is_last_stage(ignore_chunk=True):
            return input.grad
        else:
            next_rank = stage_manager.get_next_rank()
            comm.send_backward(input.grad, next_rank)
    return input.grad


def schedule_w():
    pass
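

# The interleaved PoC below runs on 4 ranks with 2 chunks per rank. The 8 layers of
# MlpModel are placed in a "V" shape: rank 0 holds layers 0 & 7, rank 1 holds 1 & 6,
# rank 2 holds 2 & 5, and rank 3 holds 3 & 4; the first layer of each pair is chunk 0
# and the second is chunk 1. The forward therefore sweeps rank 0 -> 3 (chunk 0) and
# then rank 3 -> 0 (chunk 1), and the backward retraces that path in reverse layer
# order, with schedule_b splitting each stage's backward into its dx and dw parts.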
def model_chunk_dx_dw_comm_interleaved(
    rank: int,
    world_size: int,
    port: int,
):
    # init dist
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
    pg_mesh = ProcessGroupMesh(world_size)
    stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=world_size)
    rank = dist.get_rank()
    comm = PipelineP2PCommunication(stage_manager, overlap_p2p=False)

    # init model and input
    num_layers = 8
    in_dim = out_dim = 2048
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
    model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
    input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)

    input_base = input0.clone()
    model_base = deepcopy(model)

    if rank == 0:
        # layer 0 & 7 to chunk 0 on rank0
        chunk_0 = torch.nn.ModuleList().to(rank)
        for idx, sub_model in enumerate(model.layers):
            if idx == 0 or idx == 7:
                chunk_0.append(sub_model)
    elif rank == 1:
        # layer 1 & 6 to chunk 1 on rank1
        chunk_1 = torch.nn.ModuleList().to(rank)
        for idx, sub_model in enumerate(model.layers):
            if idx == 1 or idx == 6:
                chunk_1.append(sub_model)
    elif rank == 2:
        # layer 2 & 5 to chunk 2 on rank2
        chunk_2 = torch.nn.ModuleList().to(rank)
        for idx, sub_model in enumerate(model.layers):
            if idx == 2 or idx == 5:
                chunk_2.append(sub_model)
    else:
        # layer 3 & 4 to chunk 3 on rank3
        chunk_3 = torch.nn.Sequential().to(rank)
        for idx, sub_model in enumerate(model.layers):
            if idx == 3 or idx == 4:
                chunk_3.append(sub_model)

    # # test checkpoint
    # check_fn = lambda submodule: isinstance(submodule, (Linear))
    # non_reentrant_wrapper = partial(
    #     checkpoint_wrapper,
    #     # checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    #     checkpoint_impl=CheckpointImpl.REENTRANT,
    # )
    # apply_activation_checkpointing(
    #     model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    # )

    print(
        f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
    )
    # set_checkpoint_early_stop(False)
    # buffer use to save input and output

    ##########################
    # Step1: fwd
    ##########################
    ######
    # fwd 1->4
    ######
    # chunk 0 id 0 (layer 0) fwd
    if rank == 0:
        chunk_id = 0
        input0, output0, _ = schedule_f(
            stage_manager=stage_manager,
            comm=comm,
            input=input0,
            model_chunk=chunk_0,
            model_chunk_id=chunk_id,
        )
        print(
            f"chunk 0 id 0 (layer 0)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

    # chunk 1 id 0 (layer 1) fwd
    if rank == 1:
        chunk_id = 0
        input1, output1, _ = schedule_f(
            stage_manager=stage_manager,
            comm=comm,
            input=None,
            model_chunk=chunk_1,
            model_chunk_id=chunk_id,
        )
        print(
            f"chunk 1 id 0 (layer 1)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

    # chunk 2 id 0 (layer 2) fwd
    if rank == 2:
        chunk_id = 0
        input2, output2, _ = schedule_f(
            stage_manager=stage_manager,
            comm=comm,
            input=None,
            model_chunk=chunk_2,
            model_chunk_id=chunk_id,
        )
        print(
            f"chunk 2 id 0 (layer 2)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

    # chunk 3 id 0 (layer 3) fwd
    if rank == 3:
        chunk_id = 0
        input3, output3, _ = schedule_f(
            stage_manager=stage_manager,
            comm=comm,
            input=None,
            model_chunk=chunk_3,
            model_chunk_id=chunk_id,
        )
        print(
            f"chunk 3 id 0 (layer 3)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

    ######
    # fwd 4->1
    ######
    if rank == 3:
        chunk_id = 1
        input4, output4, _ = schedule_f(
            stage_manager=stage_manager,
            comm=comm,
            input=output3,
            model_chunk=chunk_3,
            model_chunk_id=chunk_id,
        )
        print(
            f"chunk 3 id 1 (layer 4)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

    if rank == 2:
        chunk_id = 1
        input5, output5, _ = schedule_f(
            stage_manager=stage_manager,
            comm=comm,
            input=None,
            model_chunk=chunk_2,
            model_chunk_id=chunk_id,
        )
        print(
            f"chunk 2 id 1 (layer 5)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

    if rank == 1:
        chunk_id = 1
        input6, output6, _ = schedule_f(
            stage_manager=stage_manager,
            comm=comm,
            input=None,
            model_chunk=chunk_1,
            model_chunk_id=chunk_id,
        )
        print(
            f"chunk 1 id 1 (layer 6)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

    if rank == 0:
        chunk_id = 1
        input7, output7, loss = schedule_f(
            stage_manager=stage_manager,
            comm=comm,
            input=None,
            model_chunk=chunk_0,
            model_chunk_id=chunk_id,
        )
        # print(f"fwd output {output7}")
        print(
            f"chunk 0 id 1 (layer 7)fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};"
        )

    ##########################
    # Step2: bwd
    ##########################
    ######
    # bwd rank 4->1
    ######
    # chunk 0 id 1 (layer 7) bwd
    if rank == 0:
        chunk_id = 1
        input_grad7 = schedule_b(
            stage_manager=stage_manager,
            comm=comm,
            input=input7,  # x
            output=output7,  # y
            output_grad=loss,  # dy
            model_chunk=chunk_0,
            model_chunk_id=chunk_id,
        )

    # chunk 1 id 1 (layer 6) bwd
    if rank == 1:
        chunk_id = 1
        input_grad6 = schedule_b(
            stage_manager=stage_manager,
            comm=comm,
            input=input6,  # x
            output=output6,  # y
            output_grad=None,  # dy
            model_chunk=chunk_1,
            model_chunk_id=chunk_id,
        )

    # chunk 2 id 1 (layer 5) bwd
    if rank == 2:
        chunk_id = 1
        input_grad5 = schedule_b(
            stage_manager=stage_manager,
            comm=comm,
            input=input5,  # x
            output=output5,  # y
            output_grad=None,  # dy
            model_chunk=chunk_2,
            model_chunk_id=chunk_id,
        )

    # chunk 3 id 1 (layer 4) bwd
    if rank == 3:
        chunk_id = 1
        input_grad4 = schedule_b(
            stage_manager=stage_manager,
            comm=comm,
            input=input4,  # x
            output=output4,  # y
            output_grad=None,  # dy
            model_chunk=chunk_3,
            model_chunk_id=chunk_id,
        )
        # print(f"input_grad4 {input_grad4}")

    ######
    # bwd rank 1->4
    ######
    # chunk 3 id 0 (layer 3) bwd
    if rank == 3:
        chunk_id = 0
        input_grad3 = schedule_b(
            stage_manager=stage_manager,
            comm=comm,
            input=input3,  # x
            output=output3,  # y
            output_grad=input_grad4,  # dy
            model_chunk=chunk_3,
            model_chunk_id=chunk_id,
        )
        # print(f"input_grad3 {input_grad3}")

    # chunk 2 id 0 (layer 2) bwd
    if rank == 2:
        chunk_id = 0
        input_grad2 = schedule_b(
            stage_manager=stage_manager,
            comm=comm,
            input=input2,  # x
            output=output2,  # y
            output_grad=None,  # dy
            model_chunk=chunk_2,
            model_chunk_id=chunk_id,
        )
        # print(f"input_grad2 {input_grad2}")

    # chunk 1 id 0 (layer 1) bwd
    if rank == 1:
        chunk_id = 0
        input_grad1 = schedule_b(
            stage_manager=stage_manager,
            comm=comm,
            input=input1,  # x
            output=output1,  # y
            output_grad=None,  # dy
            model_chunk=chunk_1,
            model_chunk_id=chunk_id,
        )

    # chunk 0 id 0 (layer 0) bwd
    if rank == 0:
        chunk_id = 0
        input_grad0 = schedule_b(
            stage_manager=stage_manager,
            comm=comm,
            input=input0,  # x
            output=output0,  # y
            output_grad=None,  # dy
            model_chunk=chunk_0,
            model_chunk_id=chunk_id,
        )
        # print(f"input_grad0 {input_grad0}")

    ##########################
    # Fwd bwd for base
    ##########################
    # fwd & bwd
    output_base = model_base(input_base)
    loss_base = output_base.mean()
    loss_base.backward()
    print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ##########################
    # Assert close
    ##########################
    # assert output
    if rank == 0:
        assert_close(output7, output_base)

    # assert weight
    if rank == 0:
        # layer 0
        assert_close(chunk_0[0].weight, model_base.layers[0].weight)
        assert_close(chunk_0[0].weight.grad, model_base.layers[0].weight.grad)
        # layer 7
        assert_close(chunk_0[1].weight, model_base.layers[7].weight)
        assert_close(chunk_0[1].weight.grad, model_base.layers[7].weight.grad)
    if rank == 1:
        # layer 1
        assert_close(chunk_1[0].weight, model_base.layers[1].weight)
        assert_close(chunk_1[0].weight.grad, model_base.layers[1].weight.grad)
        # layer 6
        assert_close(chunk_1[1].weight, model_base.layers[6].weight)
        assert_close(chunk_1[1].weight.grad, model_base.layers[6].weight.grad)
    if rank == 2:
        # layer 2
        assert_close(chunk_2[0].weight, model_base.layers[2].weight)
        assert_close(chunk_2[0].weight.grad, model_base.layers[2].weight.grad)
        # layer 5
        assert_close(chunk_2[1].weight, model_base.layers[5].weight)
        assert_close(chunk_2[1].weight.grad, model_base.layers[5].weight.grad)
    if rank == 3:
        # layer 3
        assert_close(chunk_3[0].weight, model_base.layers[3].weight)
        assert_close(chunk_3[0].weight.grad, model_base.layers[3].weight.grad)
        # layer 4
        assert_close(chunk_3[1].weight, model_base.layers[4].weight)
        assert_close(chunk_3[1].weight.grad, model_base.layers[4].weight.grad)

    # clean memory
    if rank == 0:
        del input0, output0, input_grad0, input7, output7, input_grad7, loss
    if rank == 1:
        del input1, output1, input_grad1, input6, output6, input_grad6
    if rank == 2:
        del input2, output2, input_grad2, input5, output5, input_grad5
    if rank == 3:
        del input3, output3, input_grad3, input4, output4, input_grad4
    # print(f"After del device: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
    del loss_base, output_base
    print(f"After del: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")


@rerun_if_address_is_in_use()
def test_dx_dw_dist():
    # spawn(
    #     model_chunk_dx_dw_communication,
    #     nprocs=2,
    # )
    spawn(
        model_chunk_dx_dw_comm_interleaved,
        nprocs=4,
    )


if __name__ == "__main__":
    # test_dx_dw_split()
    # test_double_dx_dw_split_nsync()
    # test_double_dx_dw_split_sync()
    # mem_dx_dw()
    # activation_dx_dw()
    # model_chunk_dx_dw()
    test_dx_dw_dist()