diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index c928a207c..089ca48ee 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -11,6 +11,7 @@ from colossalai.interface import OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.schedule.v_schedule import ScheduledNode from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.pipeline.weight_grad_store import WeightGradStore from ._utils import ( clone, @@ -650,10 +651,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # Do not release_tensor_data loss, release_tensor_data other output_obj; if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): self.output_tensors[model_chunk_id].append(output_obj) - self.output_tensors_dw[model_chunk_id].append(output_obj) + # self.output_tensors_dw[model_chunk_id].append(output_obj) else: self.output_tensors[model_chunk_id].append(output_obj) - self.output_tensors_dw[model_chunk_id].append(output_obj) + # self.output_tensors_dw[model_chunk_id].append(output_obj) # add output to send_fwd_buffer if model_chunk_id == 0: # chunk 0 @@ -705,13 +706,13 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): input_obj = self.input_tensors[model_chunk_id].pop(0) output_obj = self.output_tensors[model_chunk_id].pop(0) - # save output_tensor_grad for dw - if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): - # we save loss here - self.output_tensors_grad_dw[model_chunk_id].append(output_obj) - else: - # we save output_tensor_grad here - self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad) + # # save output_tensor_grad for dw + # if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + # # we save loss here + # self.output_tensors_grad_dw[model_chunk_id].append(output_obj) + # else: + # # we save output_tensor_grad here + # self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad) # Step2: bwd step input_object_grad = self.backward_b_step( @@ -738,6 +739,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # send to next else: self.send_backward_buffer[model_chunk_id].append(input_object_grad) + WeightGradStore.flush(chunk=model_chunk_id) def schedule_w( self, @@ -757,16 +759,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): """ # get y & dy from buffer - output_obj = self.output_tensors_dw[model_chunk_id].pop(0) - output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0) + # output_obj = self.output_tensors_dw[model_chunk_id].pop(0) + # output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0) - self.backward_w_step( - model_chunk=model_chunk, - model_chunk_id=model_chunk_id, - optimizer=optimizer, - output_obj=output_obj, - output_obj_grad=output_obj_grad, - ) + WeightGradStore.pop(chunk=model_chunk_id) + + # self.backward_w_step( + # model_chunk=model_chunk, + # model_chunk_id=model_chunk_id, + # optimizer=optimizer, + # output_obj=output_obj, + # output_obj_grad=output_obj_grad, + # ) def run_forward_only( self, diff --git a/colossalai/pipeline/weight_grad_store.py b/colossalai/pipeline/weight_grad_store.py new file mode 100644 index 000000000..5d7f76649 --- /dev/null +++ b/colossalai/pipeline/weight_grad_store.py @@ -0,0 +1,106 @@ +import queue + +# from megatron import get_args +# from megatron.core import parallel_state +# from 
megatron.core.distributed.finalize_model_grads import _allreduce_embedding_grads +# from megatron.core.utils import get_model_config, get_attr_wrapped_model + + +class WeightGradStore: + + cache = [] + weight_grad_queue = [queue.Queue(), queue.Queue()] + + @classmethod + def put(cls, total_input, grad_output, weight, func): + # func(total_input, grad_output, weight.main_grad) + cls.cache.append((total_input, grad_output, weight, func)) + + @classmethod + def flush(cls, chunk=0): + cls.weight_grad_queue[chunk].put(cls.cache) + cls.cache = [] + + @classmethod + def pop(cls, chunk=0): + if cls.weight_grad_queue[chunk].qsize() > 0: + stored_grads = cls.weight_grad_queue[chunk].get() + for total_input, grad_output, weight, func in stored_grads: + if weight.grad is not None: + func(total_input, grad_output, weight.grad) + # for first bwd; weight.grad is None, assign grad_weight to weight.grad + else: + grad_weight = func(total_input, grad_output) + weight.grad = grad_weight + else: + raise Exception("Pop empty queue.") + + # @classmethod + # def clear(cls, model, chunk=0): + # weight_grad_tasks = [] + # while cls.weight_grad_queue[chunk].qsize() > 0: + # stored_grads = cls.weight_grad_queue[chunk].get() + # if len(weight_grad_tasks) == 0: + # for _ in stored_grads: + # weight_grad_tasks.append([]) + # else: + # assert len(weight_grad_tasks) == len(stored_grads) + # for i, task in enumerate(stored_grads): + # weight_grad_tasks[i].append(task) + # weight_params = [] + # handles = [] + # if get_args().overlap_grad_reduce: + # handles += model.async_reduce_grad() + + # output_layer_weight = None + # if parallel_state.is_pipeline_last_stage(): + # assert len(weight_grad_tasks) > 0 + # output_layer_grads = weight_grad_tasks[0] + # for j in range(len(output_layer_grads)): + # total_input, grad_output, weight, func = output_layer_grads[j] + # if output_layer_weight is None: + # output_layer_weight = weight + # assert output_layer_weight is weight + # func(total_input, grad_output, weight.main_grad) + # output_layer_grads[j] = None # release memory + # weight_grad_tasks = weight_grad_tasks[1:] + # if get_args().overlap_grad_reduce: + # handles += model.async_reduce_grad(output_layer_weight) + + # if parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage(): + # model_module = get_attr_wrapped_model(model, 'pre_process', return_model_obj=True) + # if model_module.share_embeddings_and_output_weights: + # # if share_embeddings_and_output_weights, wait all-reduce for embeddings + # for handle in handles: + # if handle is not None: + # handle.wait() + # handles = [] + + # config = get_model_config(model) + # # Do async all-reduce for embedding grads firstly, so that the rank 0 won't + # # be blocked + # embedding_handles = _allreduce_embedding_grads([model], config, async_op=True) + # handles += embedding_handles + + # for i in range(len(weight_grad_tasks)): + # tasks = weight_grad_tasks[i] + # param = None + # for j in range(len(tasks)): + # total_input, grad_output, weight, func = tasks[j] + # if param is None: + # param = weight + # assert param is weight + # assert not (weight is output_layer_weight) + # func(total_input, grad_output, weight.main_grad) + # tasks[j] = None # release memory + # weight_params.append(param) + # if get_args().overlap_grad_reduce: + # # All-reduce param grad here + # handles += model.async_reduce_grad(param) + # weight_grad_tasks[i] = None # release memory + + # # timers('wait_all_reduce', log_level=1).start(barrier=False) + # for handle in 
embedding_handles: + # if handle is not None: + # handle.wait() + # # timers('wait_all_reduce').stop() diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index aec823567..626a009ec 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -1,7 +1,11 @@ +import functools + import torch import torch.distributed as dist import torch.nn.functional as F +from colossalai.pipeline.weight_grad_store import WeightGradStore + from .utils import is_share_sp_tp try: @@ -125,12 +129,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=True): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce ctx.fp8_communication = fp8_communication + ctx.use_zbv = use_zbv if bias is not None: output = F.linear(input_, weight, bias) else: @@ -143,6 +148,14 @@ class LinearWithAsyncCommunication(torch.autograd.Function): input, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias fp8_communication = ctx.fp8_communication + use_zbv = ctx.use_zbv + + def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None): + wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_) + + def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None): + # _grad_output_.t().matmul(_input_) + return wgrad_gemm_func(_grad_output_.t(), _input_) # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. 
if use_bias: @@ -167,22 +180,60 @@ class LinearWithAsyncCommunication(torch.autograd.Function): if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad - if grad.dtype == torch.float32: - fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) - grad_weight = None - elif grad.dtype == torch.float16: - fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + if use_zbv: + # TODO: append input, grad_output_, weight, grad func to WeightGradStore + if grad.dtype == torch.float32: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32, + ), + ) + grad_weight = None + elif grad.dtype in (torch.float16, torch.bfloat16): + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass_grad_accum, + wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16, + ), + ) + grad_weight = None + else: + raise RuntimeError("Unsupported gradient type for gradient accumulation fusion") + else: + if grad.dtype == torch.float32: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad) + grad_weight = None + elif grad.dtype == torch.float16: + fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad) + grad_weight = None + else: + grad_weight = grad_output.t().matmul(total_input) + else: + if use_zbv: + WeightGradStore.put( + total_input, + grad_output, + weight, + functools.partial( + execute_w_pass, + wgrad_gemm_func=torch.matmul, + ), + ) grad_weight = None else: grad_weight = grad_output.t().matmul(total_input) - else: - grad_weight = grad_output.t().matmul(total_input) grad_bias = grad_output.sum(dim=0) if use_bias else None if ctx.async_grad_allreduce and not fp8_communication: handle.wait() - return grad_input, grad_weight, grad_bias, None, None, None, None diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index d77dd4965..25f4228a4 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -201,7 +201,6 @@ class Linear1D_Col(ParallelModule): # Matrix multiply. 
bias = self.bias if not self.skip_bias_add else None - if self.seq_parallel_mode == "split_gather": input_parallel = gather_forward_reducescatter_backward( input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 0f418edb6..4f2c45d75 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -5,6 +5,8 @@ import warnings from contextlib import nullcontext import torch + +torch.autograd.set_detect_anomaly(True) import torch.distributed as dist from data_utils import RandomDataset from model_utils import format_numel_str, get_model_numel @@ -251,6 +253,7 @@ def main(): use_fp8=args.use_fp8, fp8_communication=args.use_fp8_comm, scheduler_nodes=scheduler_nodes, + make_vocab_size_divisible_by=1, **hybrid_kwargs, ) elif args.plugin == "3d_cpu": diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 1e8f1392e..4225da802 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -926,5 +926,6 @@ def test_pp(): ) +# python -m pytest -s tests/test_pipeline/test_schedule/test_zerobubble_pp.py if __name__ == "__main__": test_pp() diff --git a/tests/test_pipeline/test_schedule/zbv_poc.py b/tests/test_pipeline/test_schedule/zbv_poc.py new file mode 100644 index 000000000..6280990a9 --- /dev/null +++ b/tests/test_pipeline/test_schedule/zbv_poc.py @@ -0,0 +1,628 @@ +import gc +import time +from copy import deepcopy + +import torch +import torch.nn as nn +from torch.testing import assert_close + + +def get_model_numel(model): + return sum(p.numel() for p in model.parameters()) / 1024**2 + + +# Step1: dx = w*dy +def backward_b(loss, x, model): + torch.autograd.backward(loss, inputs=x, retain_graph=True) + + +# Step2: dummy dw = x*dy +def backward_w(loss, model): + torch.autograd.backward(loss, inputs=list(model.parameters())) + + +def test_double_dx_dw_split_nsync(): + device = "cuda:0" + model = nn.Linear(4096, 4096, bias=None).to(device=device) + # print(f"model numel {get_model_numel(model)}") # 4GB + x1 = torch.rand(4096, 4096).to(device=device) + x2 = torch.rand(4096, 4096).to(device=device) + ref_model = deepcopy(model) + ref_x1 = x1.clone() + ref_x2 = x1.clone() + + # first step + x1.requires_grad_() + x2.requires_grad_() + ref_x1.requires_grad_() + ref_x2.requires_grad_() + + # loss for dx_dw bwd + loss1 = model(x1).sum() + loss2 = model(x2).sum() + + # loss for common bwd + ref_loss1 = ref_model(ref_x1).sum() + ref_loss2 = ref_model(ref_x2).sum() + + # dx1 + torch.cuda.synchronize() + bwd_b_start_time = time.time() + backward_b(loss1, x1, model) + bwd_b_end_time = time.time() + print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}") + + for p in model.parameters(): + assert p.grad is None + assert x1.grad is not None + + # dx2 + torch.cuda.synchronize() + bwd_b_start_time = time.time() + backward_b(loss2, x2, model) + bwd_b_end_time = time.time() + print(f"loss_2 bwd B runtime {bwd_b_end_time - bwd_b_start_time}") + + # dw1 + torch.cuda.synchronize() + bwd_w_start_time = time.time() + backward_w(loss1, model) + bwd_w_end_time = time.time() + print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}") + for p in model.parameters(): + assert p.grad is not None + + # common bwd 1 + torch.cuda.synchronize() + comm_bwd_start_time = time.time() + 
ref_loss1.backward() + comm_bwd_end_time = time.time() + print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}") + + # # assert dx1 & dw1 == bwd 1 + # assert_close(x1.grad, ref_x1.grad) + # for p1, p2 in zip(model.parameters(), ref_model.parameters()): + # assert_close(p1, p2) + # assert_close(p1.grad, p2.grad) + + # dw2 + torch.cuda.synchronize() + bwd_w_start_time = time.time() + backward_w(loss2, model) + bwd_w_end_time = time.time() + print(f"loss_2 bwd W runtime {bwd_w_end_time - bwd_w_start_time}") + + # common bwd 2 + torch.cuda.synchronize() + comm_bwd_start_time = time.time() + ref_loss2.backward() + comm_bwd_end_time = time.time() + print(f"loss_2 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}") + + # # assert dx2 & dw2 == bwd 2 + # assert_close(x2.grad, ref_x2.grad) + # for p1, p2 in zip(model.parameters(), ref_model.parameters()): + # print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n") + # assert_close(p1, p2) + # assert_close(p1.grad, p2.grad) + + +def test_double_dx_dw_split_sync(): + device = "cuda:0" + model = nn.Linear(8, 8, bias=None).to(device=device) + print(f"model size {get_model_numel(model)} ") # 4GB + x1 = torch.rand(8, 8).to(device=device) + x2 = torch.rand(8, 8).to(device=device) + + # x1 = torch.ones(8, 8).to(device=device) + # x2 = torch.ones(8, 8).to(device=device) + + ref_model = deepcopy(model) + ref_x1 = x1.clone() + ref_x2 = x2.clone() + + x1.requires_grad_() + x2.requires_grad_() + ref_x1.requires_grad_() + ref_x2.requires_grad_() + + ############ + # step1: + ############ + + # loss1 + loss1 = model(x1).sum() + + # ref_loss1 + ref_model(ref_x1).sum() + + # dx1 + backward_b(loss1, x1, model) + for p in model.parameters(): + assert p.grad is None + assert x1.grad is not None + + # dw1 + backward_w(loss1, model) + for p in model.parameters(): + assert p.grad is not None + + # common bwd 1 + # ref_loss1.backward() + + # assert dx1 & dw1 == bwd 1 + assert_close(x1.grad, ref_x1.grad) + for p1, p2 in zip(model.parameters(), ref_model.parameters()): + assert_close(p1, p2) + assert_close(p1.grad, p2.grad) + + ############ + # step2: + ############ + + # loss2 + loss2 = model(x2).sum() + + # ref_loss2 + ref_loss2 = ref_model(ref_x2).sum() + + for p1, p2 in zip(model.parameters(), ref_model.parameters()): + print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n") + assert_close(p1, p2) + assert_close(p1.grad, p2.grad) + + # dx2 + backward_b(loss2, x2, model) + + # dw2 + backward_w(loss2, model) + + # common bwd 2 + ref_loss2.backward() + + # assert dx2 & dw2 == bwd 2 + assert_close(x2.grad, ref_x2.grad) + for p1, p2 in zip(model.parameters(), ref_model.parameters()): + print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n") + assert_close(p1, p2) + assert_close(p1.grad, p2.grad) + + +def deallocate_output_tensor(out): + """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field. + + This method should be called right after the output tensor has been + sent to the next pipeline stage. At this point, the output tensor is + only useful for its '.grad_fn' field, and not its '.data'. + """ + assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__ + assert out._base is None, "counter-productive to free a view of another tensor." 
+ out.data = torch.empty( + (1,), + device=out.device, + dtype=out.dtype, + ) + + +IN_DIM = 8192 +OUT_DIM = 8192 +NUM_LAYER = 3 + + +class MlpModel(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.ModuleList([nn.Linear(IN_DIM, OUT_DIM, bias=None) for _ in range(NUM_LAYER)]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, with_qkv=True): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.with_qkv = with_qkv + if self.with_qkv: + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.attn_drop = nn.Dropout(attn_drop) + + def forward(self, x): + B, N, C = x.shape + if self.with_qkv: + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + else: + qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + q, k, v = qkv, qkv, qkv + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + if self.with_qkv: + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def mem_dx_dw(): + device = "cuda:0" + # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device) + print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + model = MlpModel().to(device=device) + print(f"model numel {get_model_numel(model)}") # 4GB + print(f"After init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + print(f"Before init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device) + x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device) + x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device) + + x1.requires_grad_() + x2.requires_grad_() + x3.requires_grad_() + print(f"After init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ############ + # step1: + ############ + print(f"\nStep1") + + # loss1 + print(f"Before Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + y1 = model(x1) + print(f"After Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + print(f"Before loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + loss1 = y1.sum() + print(f"After loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + # dx1 + backward_b(loss1, x1, model) + + # dw1 + backward_w(loss1, model) + + deallocate_output_tensor(x1) + deallocate_output_tensor(y1) + # del x1 + # del y1 + print(f"After del x1&y1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + # print(f"\n Step1:collect:{gc.collect()}") + # print(f"object: {gc.get_objects()}") + # print(f"garbage: {gc.garbage}") + + ############ + # step2: + ############ + print(f"\nStep2") + + # loss2 + y2 = model(x2) + loss2 = y2.sum() + + # dx2 + backward_b(loss2, x2, model) + + # dw2 + backward_w(loss2, model) + deallocate_output_tensor(x2) + deallocate_output_tensor(y2) + # del x2 + # del y2 + print(f"After del x2&y2: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + print(f"\n Step2:collect:{gc.collect()}") + # print(f"object: {gc.get_objects()}") + print(f"garbage: {gc.garbage}") + + ############ + # step3: + ############ + + print(f"\nStep3") + + # loss3 + y3 = model(x3) + loss3 = y3.sum() + + # dx2 + 
backward_b(loss3, x3, model) + + # dw2 + backward_w(loss3, model) + + deallocate_output_tensor(x3) + deallocate_output_tensor(y3) + # del x3 + # del y3 + + print(f"After del x3&y3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + print(f"\n Step3:collect:{gc.collect()}") + # print(f"object: {gc.get_objects()}") + print(f"garbage: {gc.garbage}") + + +# del activation +def activation_dx_dw(): + device = "cuda:0" + # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device) + print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + model = MlpModel().to(device=device) + x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device) + x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device) + x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device) + + x1.requires_grad_() + x2.requires_grad_() + x3.requires_grad_() + print(f"After init Model, x1,x2,x3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + activations = {} + + def register_hooks(module): + def activation_hook(module, input, output): + activations[f"{module.__class__.__name__}_{id(module)}"] = output.detach() + + def bwd_hook(module, grad_input, grad_output): + del activations[f"{module.__class__.__name__}_{id(module)}"] + + module.register_forward_hook(activation_hook) + module.register_backward_hook(bwd_hook) + + model.apply(register_hooks) + + ############ + # step1: + ############ + print(f"\nStep1") + + # loss1 + loss1 = model(x1).sum() + + # dx1 + backward_b(loss1, x1, model) + + # dw1 + backward_w(loss1, model) + + del loss1, x1 + print(f"After del x1&y1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ############ + # step2: + ############ + print(f"\nStep2") + + # loss2 + loss2 = model(x2).sum() + + # dx2 + backward_b(loss2, x2, model) + + # dw2 + backward_w(loss2, model) + + # deallocate_output_tensor(x2) + # deallocate_output_tensor(loss2) + del x2, loss2 + print(f"After del x2&y2: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + ############ + # step3: + ############ + print(f"\nStep3") + + # loss3 + loss3 = model(x3).sum() + + # dx2 + backward_b(loss3, x3, model) + + # dw2 + backward_w(loss3, model) + + del x3, loss3 + + print(f"After del x3&y3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + +# text dx dw in model chunk +def model_chunk_dx_dw(): + device = "cuda:0" + num_layers = 4 + print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + model = MlpModel(in_dim=4096, out_dim=4096, num_layers=num_layers).to(device=device) + x = torch.rand(4096, 4096).to(device=device) + x.requires_grad_() + + model_chunk_0 = torch.nn.ModuleList() # for layer 1 & 2 + model_chunk_1 = torch.nn.ModuleList() # for layer 3 & 4 + + for idx, sub_model in enumerate(model.layers): + if idx < 2: + model_chunk_0.append(sub_model).cuda() + else: + model_chunk_1.append(sub_model).cuda() + + print(f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + # Step1:chunk 0 fwd + activation = dict() # layer_id: activation + out = x + for i in range(len(model_chunk_0)): + layer = model_chunk_0[i] + activation[i] = layer(out) + print(f"After chunk0 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + # Step2:chunk 1 fwd + for i in range(len(model_chunk_1)): + layer = model_chunk_0[i] + activation[i + 2] = layer(out) + print(f"After chunk1 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + # Step3:chunk 1 bwd b: dx=w*dy & bwd w:dw=x*dy + # visit layer reversely + for i in range(len(model_chunk_1) - 1, -1, -1): + layer = model_chunk_1[i] + global_layer_idx = i 
+ 2 + prev_global_layer_idx = i + 1 if i + 1 > 0 else None + i + 3 if i + 3 < 4 else None + + # bwd b + if global_layer_idx == num_layers - 1: # last layer in last chunk; calculate loss + loss = activation[global_layer_idx].sum() + x = activation[prev_global_layer_idx] + backward_b(loss, x, layer) + else: + loss = activation[global_layer_idx].sum() + x = activation[prev_global_layer_idx] + backward_b(loss, x, layer) + + # bwd w + backward_w(loss, layer) + + +def test_dx_dw_linear_benchmark(): + device = "cuda:0" + model = nn.Linear(4096, 4096, bias=None).to(device=device) + # print(f"model numel {get_model_numel(model)}") # 4GB + x1 = torch.rand(4096, 4096).to(device=device) + # x2 = torch.rand(4096, 4096).to(device=device) + ref_model = deepcopy(model) + ref_x1 = x1.clone() + # ref_x2 = x1.clone() + + # first step + x1.requires_grad_() + # x2.requires_grad_() + ref_x1.requires_grad_() + # ref_x2.requires_grad_() + + # loss for dx_dw bwd + loss1 = model(x1).sum() + # loss2 = model(x2).sum() + + # loss for common bwd + ref_model(ref_x1).sum() + # ref_loss2 = ref_model(ref_x2).sum() + + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler( + f"/home/nvme-share/home/duanjunwen/ColossalAI/tests/test_pipeline/test_schedule" + ), + record_shapes=True, + profile_memory=True, + with_stack=True, + with_flops=True, + ) as prof: + # dx1 + torch.cuda.synchronize() + bwd_b_start_time = time.time() + backward_b(loss1, x1, model) + bwd_b_end_time = time.time() + print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}") + + for p in model.parameters(): + assert p.grad is None + assert x1.grad is not None + + # dw1 + torch.cuda.synchronize() + bwd_w_start_time = time.time() + backward_w(loss1, model) + bwd_w_end_time = time.time() + print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}") + for p in model.parameters(): + assert p.grad is not None + + # # common bwd 1 + # torch.cuda.synchronize() + # comm_bwd_start_time = time.time() + # ref_loss1.backward() + # comm_bwd_end_time = time.time() + # print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}") + + +def test_dx_dw_attn_benchmark(): + device = "cuda:0" + model = Attention(dim=4096).to(device=device) + # print(f"model numel {get_model_numel(model)}") # 4GB + x1 = torch.rand(1, 256, 4096).to(device=device) + # x2 = torch.rand(1, 256, 4096).to(device=device) + ref_model = deepcopy(model) + ref_x1 = x1.clone() + # ref_x2 = x1.clone() + + # first step + x1.requires_grad_() + # x2.requires_grad_() + ref_x1.requires_grad_() + # ref_x2.requires_grad_() + + # loss for dx_dw bwd + loss1 = model(x1).sum() + # loss2 = model(x2).sum() + + # loss for common bwd + ref_model(ref_x1).sum() + # ref_loss2 = ref_model(ref_x2).sum() + + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler( + f"/home/nvme-share/home/duanjunwen/ColossalAI/tests/test_pipeline/test_schedule" + ), + record_shapes=True, + profile_memory=True, + with_stack=True, + with_flops=True, + ) as prof: + # dx1 + torch.cuda.synchronize() + bwd_b_start_time = time.time() + backward_b(loss1, x1, model) + bwd_b_end_time = time.time() + print(f"loss_1 bwd B runtime 
{bwd_b_end_time - bwd_b_start_time}") + + for p in model.parameters(): + assert p.grad is None + assert x1.grad is not None + + # dw1 + torch.cuda.synchronize() + bwd_w_start_time = time.time() + backward_w(loss1, model) + bwd_w_end_time = time.time() + print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}") + for p in model.parameters(): + assert p.grad is not None + + # # common bwd 1 + # torch.cuda.synchronize() + # comm_bwd_start_time = time.time() + # ref_loss1.backward() + # comm_bwd_end_time = time.time() + # print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}") + + +if __name__ == "__main__": + # test_dx_dw_split() + # test_double_dx_dw_split_nsync() + # test_double_dx_dw_split_sync() + # mem_dx_dw() + # activation_dx_dw() + # test_dx_dw_linear_benchmark() + test_dx_dw_attn_benchmark()
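
For reference, a minimal sketch of the deferred-dW pattern that `WeightGradStore` and the `use_zbv` branch in `LinearWithAsyncCommunication.backward` rely on: the backward pass returns only the input gradient (the B step) and enqueues `(input, grad_output, weight, func)`; `flush()` seals what was cached during the current chunk's B step, and `pop()` replays the weight-gradient computation during the W step. The `ToyWeightGradStore`/`ToyLinear` names below are illustrative only (they are not part of this patch or the ColossalAI API), and the sketch assumes a single linear layer with no pipeline, tensor parallelism, or fused wgrad kernels.

```python
import queue

import torch
import torch.nn as nn


class ToyWeightGradStore:
    """Illustrative stand-in for WeightGradStore: defer dW until pop()."""

    cache = []
    weight_grad_queue = queue.Queue()

    @classmethod
    def put(cls, inp, grad_output, weight, func):
        cls.cache.append((inp, grad_output, weight, func))

    @classmethod
    def flush(cls):
        # Seal everything cached during the current B step.
        cls.weight_grad_queue.put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls):
        # Replay the deferred dW computations (the W step).
        stored = cls.weight_grad_queue.get()
        with torch.no_grad():
            for inp, grad_output, weight, func in stored:
                grad_weight = func(grad_output.t(), inp)  # dW = dY^T @ X
                if weight.grad is None:
                    weight.grad = grad_weight
                else:
                    weight.grad += grad_weight


class ToyLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp, weight)
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        inp, weight = ctx.saved_tensors
        grad_input = grad_output @ weight  # B step: dX is produced immediately
        ToyWeightGradStore.put(inp, grad_output, weight, torch.matmul)  # W step deferred
        return grad_input, None  # no dW returned here


if __name__ == "__main__":
    torch.manual_seed(0)
    weight = nn.Parameter(torch.randn(4, 4))
    x = torch.randn(2, 4, requires_grad=True)

    ToyLinear.apply(x, weight).sum().backward()  # like schedule_b: only dX
    assert x.grad is not None and weight.grad is None

    ToyWeightGradStore.flush()  # end of the B step for this chunk
    ToyWeightGradStore.pop()    # like schedule_w: dW materializes now
    assert weight.grad is not None

    # Cross-check against a plain autograd backward.
    x_ref = x.detach().clone().requires_grad_()
    w_ref = nn.Parameter(weight.detach().clone())
    (x_ref @ w_ref.t()).sum().backward()
    torch.testing.assert_close(weight.grad, w_ref.grad)
    torch.testing.assert_close(x.grad, x_ref.grad)
```

This is also why `schedule_b` in zero_bubble_pp.py can drop the `output_tensors_dw`/`output_tensors_grad_dw` bookkeeping: the tensors needed for dW now travel through the store (put at backward time, flushed per chunk, popped in `schedule_w`) rather than through the scheduler's own buffers.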