[feat] Linear1D_COL/ROW support zbv WeightGradStore;

pull/6083/head
duanjunwen 2024-10-14 07:02:43 +00:00
parent 0ca16d5cbe
commit cfade4c36d
7 changed files with 820 additions and 28 deletions


@@ -11,6 +11,7 @@ from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.pipeline.weight_grad_store import WeightGradStore
from ._utils import (
clone,
@@ -650,10 +651,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# Do not release_tensor_data loss, release_tensor_data other output_obj;
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
self.output_tensors[model_chunk_id].append(output_obj)
self.output_tensors_dw[model_chunk_id].append(output_obj)
# self.output_tensors_dw[model_chunk_id].append(output_obj)
else:
self.output_tensors[model_chunk_id].append(output_obj)
self.output_tensors_dw[model_chunk_id].append(output_obj)
# self.output_tensors_dw[model_chunk_id].append(output_obj)
# add output to send_fwd_buffer
if model_chunk_id == 0: # chunk 0
@@ -705,13 +706,13 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
input_obj = self.input_tensors[model_chunk_id].pop(0)
output_obj = self.output_tensors[model_chunk_id].pop(0)
# save output_tensor_grad for dw
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# we save loss here
self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
else:
# we save output_tensor_grad here
self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
# # save output_tensor_grad for dw
# if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# # we save loss here
# self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
# else:
# # we save output_tensor_grad here
# self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
# Step2: bwd step
input_object_grad = self.backward_b_step(
@@ -738,6 +739,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# send to next
else:
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
WeightGradStore.flush(chunk=model_chunk_id)
def schedule_w(
self,
@@ -757,16 +759,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
"""
# get y & dy from buffer
output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
# output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
# output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
self.backward_w_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
optimizer=optimizer,
output_obj=output_obj,
output_obj_grad=output_obj_grad,
)
WeightGradStore.pop(chunk=model_chunk_id)
# self.backward_w_step(
# model_chunk=model_chunk,
# model_chunk_id=model_chunk_id,
# optimizer=optimizer,
# output_obj=output_obj,
# output_obj_grad=output_obj_grad,
# )
def run_forward_only(
self,

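The two hunks above change the B/W pairing: schedule_b no longer stashes output tensors for a later backward_w_step; it records the per-micro-batch weight-gradient work during backward_b_step and then calls WeightGradStore.flush(chunk=model_chunk_id), while schedule_w replays exactly one flushed micro-batch via WeightGradStore.pop(chunk=model_chunk_id). A minimal sketch of that pairing with toy tensors (single chunk, one micro-batch; the lambda is a stand-in for the GEMM callbacks that LinearWithAsyncCommunication.backward registers below):

    import torch
    from colossalai.pipeline.weight_grad_store import WeightGradStore

    weight = torch.nn.Parameter(torch.zeros(4, 4))
    total_input, grad_output = torch.ones(2, 4), torch.ones(2, 4)

    # B pass of one micro-batch on chunk 0: record the dW work, then flush it as one unit.
    WeightGradStore.put(total_input, grad_output, weight,
                        lambda inp, dy, grad=None: dy.t() @ inp)  # stand-in for the real dW callback
    WeightGradStore.flush(chunk=0)

    # W pass, scheduled later: pop() runs the stored callback and fills weight.grad.
    WeightGradStore.pop(chunk=0)
    assert weight.grad is not None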

@@ -0,0 +1,106 @@
import queue
# from megatron import get_args
# from megatron.core import parallel_state
# from megatron.core.distributed.finalize_model_grads import _allreduce_embedding_grads
# from megatron.core.utils import get_model_config, get_attr_wrapped_model
class WeightGradStore:
cache = []
weight_grad_queue = [queue.Queue(), queue.Queue()]
@classmethod
def put(cls, total_input, grad_output, weight, func):
# func(total_input, grad_output, weight.main_grad)
cls.cache.append((total_input, grad_output, weight, func))
@classmethod
def flush(cls, chunk=0):
cls.weight_grad_queue[chunk].put(cls.cache)
cls.cache = []
@classmethod
def pop(cls, chunk=0):
if cls.weight_grad_queue[chunk].qsize() > 0:
stored_grads = cls.weight_grad_queue[chunk].get()
for total_input, grad_output, weight, func in stored_grads:
if weight.grad is not None:
func(total_input, grad_output, weight.grad)
# for first bwd; weight.grad is None, assign grad_weight to weight.grad
else:
grad_weight = func(total_input, grad_output)
weight.grad = grad_weight
else:
raise Exception("Pop empty queue.")
# @classmethod
# def clear(cls, model, chunk=0):
# weight_grad_tasks = []
# while cls.weight_grad_queue[chunk].qsize() > 0:
# stored_grads = cls.weight_grad_queue[chunk].get()
# if len(weight_grad_tasks) == 0:
# for _ in stored_grads:
# weight_grad_tasks.append([])
# else:
# assert len(weight_grad_tasks) == len(stored_grads)
# for i, task in enumerate(stored_grads):
# weight_grad_tasks[i].append(task)
# weight_params = []
# handles = []
# if get_args().overlap_grad_reduce:
# handles += model.async_reduce_grad()
# output_layer_weight = None
# if parallel_state.is_pipeline_last_stage():
# assert len(weight_grad_tasks) > 0
# output_layer_grads = weight_grad_tasks[0]
# for j in range(len(output_layer_grads)):
# total_input, grad_output, weight, func = output_layer_grads[j]
# if output_layer_weight is None:
# output_layer_weight = weight
# assert output_layer_weight is weight
# func(total_input, grad_output, weight.main_grad)
# output_layer_grads[j] = None # release memory
# weight_grad_tasks = weight_grad_tasks[1:]
# if get_args().overlap_grad_reduce:
# handles += model.async_reduce_grad(output_layer_weight)
# if parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage():
# model_module = get_attr_wrapped_model(model, 'pre_process', return_model_obj=True)
# if model_module.share_embeddings_and_output_weights:
# # if share_embeddings_and_output_weights, wait all-reduce for embeddings
# for handle in handles:
# if handle is not None:
# handle.wait()
# handles = []
# config = get_model_config(model)
# # Do async all-reduce for embedding grads firstly, so that the rank 0 won't
# # be blocked
# embedding_handles = _allreduce_embedding_grads([model], config, async_op=True)
# handles += embedding_handles
# for i in range(len(weight_grad_tasks)):
# tasks = weight_grad_tasks[i]
# param = None
# for j in range(len(tasks)):
# total_input, grad_output, weight, func = tasks[j]
# if param is None:
# param = weight
# assert param is weight
# assert not (weight is output_layer_weight)
# func(total_input, grad_output, weight.main_grad)
# tasks[j] = None # release memory
# weight_params.append(param)
# if get_args().overlap_grad_reduce:
# # All-reduce param grad here
# handles += model.async_reduce_grad(param)
# weight_grad_tasks[i] = None # release memory
# # timers('wait_all_reduce', log_level=1).start(barrier=False)
# for handle in embedding_handles:
# if handle is not None:
# handle.wait()
# # timers('wait_all_reduce').stop()

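The func stored via put() has to follow the two calling conventions pop() uses above: with two arguments (weight.grad still None) it must return the weight gradient, and with three arguments it must accumulate into the passed grad in place. A small sketch of a callback obeying that contract (matmul_wgrad is a hypothetical plain-matmul stand-in for the fused kernels used elsewhere in this commit):

    import torch

    def matmul_wgrad(total_input, grad_output, weight_grad=None):
        dw = grad_output.t().matmul(total_input)  # dW = dy^T x
        if weight_grad is None:
            return dw                  # pop() assigns this to weight.grad on the first W pass
        weight_grad.add_(dw)           # pop() relies on in-place accumulation afterwards

    w = torch.nn.Parameter(torch.zeros(3, 3))
    x, dy = torch.ones(2, 3), torch.ones(2, 3)

    w.grad = matmul_wgrad(x, dy)       # first W pass: assign
    matmul_wgrad(x, dy, w.grad)        # later W passes: accumulate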

@@ -1,7 +1,11 @@
import functools
import torch
import torch.distributed as dist
import torch.nn.functional as F
from colossalai.pipeline.weight_grad_store import WeightGradStore
from .utils import is_share_sp_tp
try:
@@ -125,12 +129,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=True):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
ctx.fp8_communication = fp8_communication
ctx.use_zbv = use_zbv
if bias is not None:
output = F.linear(input_, weight, bias)
else:
@@ -143,6 +148,14 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
fp8_communication = ctx.fp8_communication
use_zbv = ctx.use_zbv
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
# _grad_output_.t().matmul(_input_)
return wgrad_gemm_func(_grad_output_.t(), _input_)
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias:
@@ -167,22 +180,60 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
if use_zbv:
# Append input, grad_output, weight and the dW func to WeightGradStore
if grad.dtype == torch.float32:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
),
)
grad_weight = None
elif grad.dtype in (torch.float16, torch.bfloat16):
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
),
)
grad_weight = None
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
else:
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass,
wgrad_gemm_func=torch.matmul,
),
)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce and not fp8_communication:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None

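With use_zbv set, the dW GEMM above is not executed inside backward(); it is parked in WeightGradStore behind a functools.partial and replayed when the scheduler calls pop(). A minimal check (plain tensors, no process group, no grad-accum fusion) that the deferred path reproduces the eager grad_weight = grad_output.t().matmul(total_input); execute_w_pass is copied from the backward above:

    import functools
    import torch
    from colossalai.pipeline.weight_grad_store import WeightGradStore

    def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
        return wgrad_gemm_func(_grad_output_.t(), _input_)

    weight = torch.nn.Parameter(torch.rand(8, 16))   # (out_features, in_features)
    total_input = torch.rand(4, 16)
    grad_output = torch.rand(4, 8)

    eager = grad_output.t().matmul(total_input)      # non-zbv branch

    WeightGradStore.put(total_input, grad_output, weight,
                        functools.partial(execute_w_pass, wgrad_gemm_func=torch.matmul))
    WeightGradStore.flush(chunk=0)
    WeightGradStore.pop(chunk=0)                     # deferred zbv branch, run at W time

    assert torch.allclose(weight.grad, eager)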

@@ -201,7 +201,6 @@ class Linear1D_Col(ParallelModule):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
if self.seq_parallel_mode == "split_gather":
input_parallel = gather_forward_reducescatter_backward(
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication


@@ -5,6 +5,8 @@ import warnings
from contextlib import nullcontext
import torch
torch.autograd.set_detect_anomaly(True)
import torch.distributed as dist
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
@@ -251,6 +253,7 @@ def main():
use_fp8=args.use_fp8,
fp8_communication=args.use_fp8_comm,
scheduler_nodes=scheduler_nodes,
make_vocab_size_divisible_by=1,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":


@@ -926,5 +926,6 @@ def test_pp():
)
# python -m pytest -s tests/test_pipeline/test_schedule/test_zerobubble_pp.py
if __name__ == "__main__":
test_pp()


@@ -0,0 +1,628 @@
import gc
import time
from copy import deepcopy
import torch
import torch.nn as nn
from torch.testing import assert_close
def get_model_numel(model):
return sum(p.numel() for p in model.parameters()) / 1024**2
# Step1: dx = w*dy
def backward_b(loss, x, model):
torch.autograd.backward(loss, inputs=x, retain_graph=True)
# Step2: dummy dw = x*dy
def backward_w(loss, model):
torch.autograd.backward(loss, inputs=list(model.parameters()))
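For a bias-free linear layer y = x @ W.T with a sum loss, the split above amounts to dx = dy @ W in the B step and dW = dy.T @ x in the W step. A minimal check with toy shapes (mirrors backward_b/backward_w via torch.autograd.backward with the inputs argument):

    import torch
    import torch.nn as nn

    lin = nn.Linear(4, 4, bias=False)
    x = torch.rand(2, 4, requires_grad=True)
    y = lin(x)
    loss = y.sum()
    dy = torch.ones_like(y)  # dloss/dy for a sum loss

    # B step: gradient w.r.t. the input only; parameters are left untouched
    torch.autograd.backward(loss, inputs=[x], retain_graph=True)
    assert lin.weight.grad is None

    # W step: gradient w.r.t. the parameters only
    torch.autograd.backward(loss, inputs=list(lin.parameters()))

    assert torch.allclose(x.grad, dy @ lin.weight)               # dx = dy @ W
    assert torch.allclose(lin.weight.grad, dy.t() @ x.detach())  # dW = dy^T x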
def test_double_dx_dw_split_nsync():
device = "cuda:0"
model = nn.Linear(4096, 4096, bias=None).to(device=device)
# print(f"model numel {get_model_numel(model)}") # 4GB
x1 = torch.rand(4096, 4096).to(device=device)
x2 = torch.rand(4096, 4096).to(device=device)
ref_model = deepcopy(model)
ref_x1 = x1.clone()
ref_x2 = x2.clone()
# first step
x1.requires_grad_()
x2.requires_grad_()
ref_x1.requires_grad_()
ref_x2.requires_grad_()
# loss for dx_dw bwd
loss1 = model(x1).sum()
loss2 = model(x2).sum()
# loss for common bwd
ref_loss1 = ref_model(ref_x1).sum()
ref_loss2 = ref_model(ref_x2).sum()
# dx1
torch.cuda.synchronize()
bwd_b_start_time = time.time()
backward_b(loss1, x1, model)
bwd_b_end_time = time.time()
print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")
for p in model.parameters():
assert p.grad is None
assert x1.grad is not None
# dx2
torch.cuda.synchronize()
bwd_b_start_time = time.time()
backward_b(loss2, x2, model)
bwd_b_end_time = time.time()
print(f"loss_2 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")
# dw1
torch.cuda.synchronize()
bwd_w_start_time = time.time()
backward_w(loss1, model)
bwd_w_end_time = time.time()
print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")
for p in model.parameters():
assert p.grad is not None
# common bwd 1
torch.cuda.synchronize()
comm_bwd_start_time = time.time()
ref_loss1.backward()
comm_bwd_end_time = time.time()
print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")
# # assert dx1 & dw1 == bwd 1
# assert_close(x1.grad, ref_x1.grad)
# for p1, p2 in zip(model.parameters(), ref_model.parameters()):
# assert_close(p1, p2)
# assert_close(p1.grad, p2.grad)
# dw2
torch.cuda.synchronize()
bwd_w_start_time = time.time()
backward_w(loss2, model)
bwd_w_end_time = time.time()
print(f"loss_2 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")
# common bwd 2
torch.cuda.synchronize()
comm_bwd_start_time = time.time()
ref_loss2.backward()
comm_bwd_end_time = time.time()
print(f"loss_2 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")
# # assert dx2 & dw2 == bwd 2
# assert_close(x2.grad, ref_x2.grad)
# for p1, p2 in zip(model.parameters(), ref_model.parameters()):
# print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
# assert_close(p1, p2)
# assert_close(p1.grad, p2.grad)
def test_double_dx_dw_split_sync():
device = "cuda:0"
model = nn.Linear(8, 8, bias=None).to(device=device)
print(f"model size {get_model_numel(model)} ") # 4GB
x1 = torch.rand(8, 8).to(device=device)
x2 = torch.rand(8, 8).to(device=device)
# x1 = torch.ones(8, 8).to(device=device)
# x2 = torch.ones(8, 8).to(device=device)
ref_model = deepcopy(model)
ref_x1 = x1.clone()
ref_x2 = x2.clone()
x1.requires_grad_()
x2.requires_grad_()
ref_x1.requires_grad_()
ref_x2.requires_grad_()
############
# step1:
############
# loss1
loss1 = model(x1).sum()
# ref_loss1
ref_loss1 = ref_model(ref_x1).sum()
# dx1
backward_b(loss1, x1, model)
for p in model.parameters():
assert p.grad is None
assert x1.grad is not None
# dw1
backward_w(loss1, model)
for p in model.parameters():
assert p.grad is not None
# common bwd 1
ref_loss1.backward()
# assert dx1 & dw1 == bwd 1
assert_close(x1.grad, ref_x1.grad)
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
assert_close(p1, p2)
assert_close(p1.grad, p2.grad)
############
# step2:
############
# loss2
loss2 = model(x2).sum()
# ref_loss2
ref_loss2 = ref_model(ref_x2).sum()
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
assert_close(p1, p2)
assert_close(p1.grad, p2.grad)
# dx2
backward_b(loss2, x2, model)
# dw2
backward_w(loss2, model)
# common bwd 2
ref_loss2.backward()
# assert dx2 & dw2 == bwd 2
assert_close(x2.grad, ref_x2.grad)
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
assert_close(p1, p2)
assert_close(p1.grad, p2.grad)
def deallocate_output_tensor(out):
"""Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
This method should be called right after the output tensor has been
sent to the next pipeline stage. At this point, the output tensor is
only useful for its '.grad_fn' field, and not its '.data'.
"""
assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
assert out._base is None, "counter-productive to free a view of another tensor."
out.data = torch.empty(
(1,),
device=out.device,
dtype=out.dtype,
)
IN_DIM = 8192
OUT_DIM = 8192
NUM_LAYER = 3
class MlpModel(nn.Module):
def __init__(self, in_dim=IN_DIM, out_dim=OUT_DIM, num_layers=NUM_LAYER):
super().__init__()
self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, with_qkv=True):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.with_qkv = with_qkv
if self.with_qkv:
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.attn_drop = nn.Dropout(attn_drop)
def forward(self, x):
B, N, C = x.shape
if self.with_qkv:
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
else:
qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
q, k, v = qkv, qkv, qkv
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
if self.with_qkv:
x = self.proj(x)
x = self.proj_drop(x)
return x
def mem_dx_dw():
device = "cuda:0"
# model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
model = MlpModel().to(device=device)
print(f"model numel {get_model_numel(model)}") # 4GB
print(f"After init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
print(f"Before init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
x1.requires_grad_()
x2.requires_grad_()
x3.requires_grad_()
print(f"After init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
############
# step1:
############
print(f"\nStep1")
# loss1
print(f"Before Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
y1 = model(x1)
print(f"After Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
print(f"Before loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
loss1 = y1.sum()
print(f"After loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
# dx1
backward_b(loss1, x1, model)
# dw1
backward_w(loss1, model)
deallocate_output_tensor(x1)
deallocate_output_tensor(y1)
# del x1
# del y1
print(f"After del x1&y1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
# print(f"\n Step1:collect:{gc.collect()}")
# print(f"object: {gc.get_objects()}")
# print(f"garbage: {gc.garbage}")
############
# step2:
############
print(f"\nStep2")
# loss2
y2 = model(x2)
loss2 = y2.sum()
# dx2
backward_b(loss2, x2, model)
# dw2
backward_w(loss2, model)
deallocate_output_tensor(x2)
deallocate_output_tensor(y2)
# del x2
# del y2
print(f"After del x2&y2: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
print(f"\n Step2:collect:{gc.collect()}")
# print(f"object: {gc.get_objects()}")
print(f"garbage: {gc.garbage}")
############
# step3:
############
print(f"\nStep3")
# loss3
y3 = model(x3)
loss3 = y3.sum()
# dx2
backward_b(loss3, x3, model)
# dw2
backward_w(loss3, model)
deallocate_output_tensor(x3)
deallocate_output_tensor(y3)
# del x3
# del y3
print(f"After del x3&y3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
print(f"\n Step3:collect:{gc.collect()}")
# print(f"object: {gc.get_objects()}")
print(f"garbage: {gc.garbage}")
# del activation
def activation_dx_dw():
device = "cuda:0"
# model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
model = MlpModel().to(device=device)
x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
x1.requires_grad_()
x2.requires_grad_()
x3.requires_grad_()
print(f"After init Model, x1,x2,x3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
activations = {}
def register_hooks(module):
def activation_hook(module, input, output):
activations[f"{module.__class__.__name__}_{id(module)}"] = output.detach()
def bwd_hook(module, grad_input, grad_output):
del activations[f"{module.__class__.__name__}_{id(module)}"]
module.register_forward_hook(activation_hook)
module.register_backward_hook(bwd_hook)
model.apply(register_hooks)
############
# step1:
############
print(f"\nStep1")
# loss1
loss1 = model(x1).sum()
# dx1
backward_b(loss1, x1, model)
# dw1
backward_w(loss1, model)
del loss1, x1
print(f"After del x1&y1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
############
# step2:
############
print(f"\nStep2")
# loss2
loss2 = model(x2).sum()
# dx2
backward_b(loss2, x2, model)
# dw2
backward_w(loss2, model)
# deallocate_output_tensor(x2)
# deallocate_output_tensor(loss2)
del x2, loss2
print(f"After del x2&y2: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
############
# step3:
############
print(f"\nStep3")
# loss3
loss3 = model(x3).sum()
# dx2
backward_b(loss3, x3, model)
# dw2
backward_w(loss3, model)
del x3, loss3
print(f"After del x3&y3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
# test dx dw in model chunk
def model_chunk_dx_dw():
device = "cuda:0"
num_layers = 4
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
model = MlpModel(in_dim=4096, out_dim=4096, num_layers=num_layers).to(device=device)
x = torch.rand(4096, 4096).to(device=device)
x.requires_grad_()
model_chunk_0 = torch.nn.ModuleList() # for layer 1 & 2
model_chunk_1 = torch.nn.ModuleList() # for layer 3 & 4
for idx, sub_model in enumerate(model.layers):
if idx < 2:
model_chunk_0.append(sub_model).cuda()
else:
model_chunk_1.append(sub_model).cuda()
print(f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
# Step1:chunk 0 fwd
activation = dict() # layer_id: activation
out = x
for i in range(len(model_chunk_0)):
layer = model_chunk_0[i]
out = layer(out)
activation[i] = out
print(f"After chunk0 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
# Step2:chunk 1 fwd
for i in range(len(model_chunk_1)):
layer = model_chunk_1[i]
out = layer(out)
activation[i + 2] = out
print(f"After chunk1 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
# Step3:chunk 1 bwd b: dx=w*dy & bwd w:dw=x*dy
# visit layer reversely
for i in range(len(model_chunk_1) - 1, -1, -1):
layer = model_chunk_1[i]
global_layer_idx = i + 2
prev_global_layer_idx = i + 1 if i + 1 > 0 else None
# bwd b
if global_layer_idx == num_layers - 1: # last layer in last chunk; calculate loss
loss = activation[global_layer_idx].sum()
x = activation[prev_global_layer_idx]
backward_b(loss, x, layer)
else:
loss = activation[global_layer_idx].sum()
x = activation[prev_global_layer_idx]
backward_b(loss, x, layer)
# bwd w
backward_w(loss, layer)
def test_dx_dw_linear_benchmark():
device = "cuda:0"
model = nn.Linear(4096, 4096, bias=None).to(device=device)
# print(f"model numel {get_model_numel(model)}") # 4GB
x1 = torch.rand(4096, 4096).to(device=device)
# x2 = torch.rand(4096, 4096).to(device=device)
ref_model = deepcopy(model)
ref_x1 = x1.clone()
# ref_x2 = x1.clone()
# first step
x1.requires_grad_()
# x2.requires_grad_()
ref_x1.requires_grad_()
# ref_x2.requires_grad_()
# loss for dx_dw bwd
loss1 = model(x1).sum()
# loss2 = model(x2).sum()
# loss for common bwd
ref_model(ref_x1).sum()
# ref_loss2 = ref_model(ref_x2).sum()
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
# schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(
f"/home/nvme-share/home/duanjunwen/ColossalAI/tests/test_pipeline/test_schedule"
),
record_shapes=True,
profile_memory=True,
with_stack=True,
with_flops=True,
) as prof:
# dx1
torch.cuda.synchronize()
bwd_b_start_time = time.time()
backward_b(loss1, x1, model)
bwd_b_end_time = time.time()
print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")
for p in model.parameters():
assert p.grad is None
assert x1.grad is not None
# dw1
torch.cuda.synchronize()
bwd_w_start_time = time.time()
backward_w(loss1, model)
bwd_w_end_time = time.time()
print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")
for p in model.parameters():
assert p.grad is not None
# # common bwd 1
# torch.cuda.synchronize()
# comm_bwd_start_time = time.time()
# ref_loss1.backward()
# comm_bwd_end_time = time.time()
# print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")
def test_dx_dw_attn_benchmark():
device = "cuda:0"
model = Attention(dim=4096).to(device=device)
# print(f"model numel {get_model_numel(model)}") # 4GB
x1 = torch.rand(1, 256, 4096).to(device=device)
# x2 = torch.rand(1, 256, 4096).to(device=device)
ref_model = deepcopy(model)
ref_x1 = x1.clone()
# ref_x2 = x1.clone()
# first step
x1.requires_grad_()
# x2.requires_grad_()
ref_x1.requires_grad_()
# ref_x2.requires_grad_()
# loss for dx_dw bwd
loss1 = model(x1).sum()
# loss2 = model(x2).sum()
# loss for common bwd
ref_model(ref_x1).sum()
# ref_loss2 = ref_model(ref_x2).sum()
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
# schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(
f"/home/nvme-share/home/duanjunwen/ColossalAI/tests/test_pipeline/test_schedule"
),
record_shapes=True,
profile_memory=True,
with_stack=True,
with_flops=True,
) as prof:
# dx1
torch.cuda.synchronize()
bwd_b_start_time = time.time()
backward_b(loss1, x1, model)
bwd_b_end_time = time.time()
print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")
for p in model.parameters():
assert p.grad is None
assert x1.grad is not None
# dw1
torch.cuda.synchronize()
bwd_w_start_time = time.time()
backward_w(loss1, model)
bwd_w_end_time = time.time()
print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")
for p in model.parameters():
assert p.grad is not None
# # common bwd 1
# torch.cuda.synchronize()
# comm_bwd_start_time = time.time()
# ref_loss1.backward()
# comm_bwd_end_time = time.time()
# print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")
if __name__ == "__main__":
# test_dx_dw_split()
# test_double_dx_dw_split_nsync()
# test_double_dx_dw_split_sync()
# mem_dx_dw()
# activation_dx_dw()
# test_dx_dw_linear_benchmark()
test_dx_dw_attn_benchmark()