mirror of https://github.com/hpcaitech/ColossalAI
[feat] Linear1D_COL/ROW support zbv WeightGradStore;
parent 0ca16d5cbe
commit cfade4c36d
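Context for the diff below: in the zero-bubble (ZBV) pipeline schedule, a linear layer's backward is split into a B pass (input gradient, on the critical path for the upstream stage) and a W pass (weight gradient, which can be deferred to fill pipeline bubbles). A minimal illustration of that split for a single matmul (our toy sketch, not code from the commit):

import torch

# Forward: y = x @ W.t().  Backward splits into two independent GEMMs:
#   B pass: dx = dy @ W        (blocks the upstream stage, run immediately)
#   W pass: dW = dy.t() @ x    (only updates this stage, deferrable)
x = torch.randn(16, 32, requires_grad=True)
w = torch.randn(64, 32, requires_grad=True)
y = x.matmul(w.t())
dy = torch.ones_like(y)

dx = dy.matmul(w)       # B pass
dw = dy.t().matmul(x)   # W pass, deferrable

# Reference: autograd computes the same quantities.
y.backward(dy)
torch.testing.assert_close(dx, x.grad)
torch.testing.assert_close(dw, w.grad)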
colossalai/pipeline/schedule/zero_bubble_pp.py:

@@ -11,6 +11,7 @@ from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.schedule.v_schedule import ScheduledNode
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.pipeline.weight_grad_store import WeightGradStore

 from ._utils import (
     clone,
@@ -650,10 +651,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # Do not release_tensor_data loss, release_tensor_data other output_obj;
         if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             self.output_tensors[model_chunk_id].append(output_obj)
-            self.output_tensors_dw[model_chunk_id].append(output_obj)
+            # self.output_tensors_dw[model_chunk_id].append(output_obj)
         else:
             self.output_tensors[model_chunk_id].append(output_obj)
-            self.output_tensors_dw[model_chunk_id].append(output_obj)
+            # self.output_tensors_dw[model_chunk_id].append(output_obj)

         # add output to send_fwd_buffer
         if model_chunk_id == 0:  # chunk 0
@@ -705,13 +706,13 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         input_obj = self.input_tensors[model_chunk_id].pop(0)
         output_obj = self.output_tensors[model_chunk_id].pop(0)

-        # save output_tensor_grad for dw
-        if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            # we save loss here
-            self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
-        else:
-            # we save output_tensor_grad here
-            self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
+        # # save output_tensor_grad for dw
+        # if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
+        #     # we save loss here
+        #     self.output_tensors_grad_dw[model_chunk_id].append(output_obj)
+        # else:
+        #     # we save output_tensor_grad here
+        #     self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)

         # Step2: bwd step
         input_object_grad = self.backward_b_step(
@@ -738,6 +739,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # send to next
         else:
             self.send_backward_buffer[model_chunk_id].append(input_object_grad)
+        WeightGradStore.flush(chunk=model_chunk_id)

     def schedule_w(
         self,
@@ -757,16 +759,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """

         # get y & dy from buffer
-        output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
-        output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
+        # output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
+        # output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)

-        self.backward_w_step(
-            model_chunk=model_chunk,
-            model_chunk_id=model_chunk_id,
-            optimizer=optimizer,
-            output_obj=output_obj,
-            output_obj_grad=output_obj_grad,
-        )
+        WeightGradStore.pop(chunk=model_chunk_id)
+
+        # self.backward_w_step(
+        #     model_chunk=model_chunk,
+        #     model_chunk_id=model_chunk_id,
+        #     optimizer=optimizer,
+        #     output_obj=output_obj,
+        #     output_obj_grad=output_obj_grad,
+        # )

     def run_forward_only(
         self,
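Taken together, the scheduler changes make schedule_b a producer and schedule_w a consumer: the W passes recorded during backward_b_step are flushed into a per-chunk queue, and schedule_w later pops and executes one flushed group. A toy sketch of that producer/consumer contract (illustrative names only, not the scheduler itself):

import queue

class ToyStore:
    """Mimics WeightGradStore: put() collects deferred W-pass closures for the
    current micro-batch, flush() seals them as one group, pop() runs one group."""
    cache = []
    weight_grad_queue = [queue.Queue(), queue.Queue()]

    @classmethod
    def put(cls, fn):
        cls.cache.append(fn)

    @classmethod
    def flush(cls, chunk=0):
        cls.weight_grad_queue[chunk].put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls, chunk=0):
        for fn in cls.weight_grad_queue[chunk].get():
            fn()

# schedule_b for one micro-batch: B pass runs now, W passes are queued.
ToyStore.put(lambda: print("W-pass GEMM for layer 0"))
ToyStore.put(lambda: print("W-pass GEMM for layer 1"))
ToyStore.flush(chunk=0)

# schedule_w, possibly much later in the pipeline timetable:
ToyStore.pop(chunk=0)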
New file: colossalai/pipeline/weight_grad_store.py

@@ -0,0 +1,106 @@
import queue

# from megatron import get_args
# from megatron.core import parallel_state
# from megatron.core.distributed.finalize_model_grads import _allreduce_embedding_grads
# from megatron.core.utils import get_model_config, get_attr_wrapped_model


class WeightGradStore:

    cache = []
    weight_grad_queue = [queue.Queue(), queue.Queue()]

    @classmethod
    def put(cls, total_input, grad_output, weight, func):
        # func(total_input, grad_output, weight.main_grad)
        cls.cache.append((total_input, grad_output, weight, func))

    @classmethod
    def flush(cls, chunk=0):
        cls.weight_grad_queue[chunk].put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls, chunk=0):
        if cls.weight_grad_queue[chunk].qsize() > 0:
            stored_grads = cls.weight_grad_queue[chunk].get()
            for total_input, grad_output, weight, func in stored_grads:
                if weight.grad is not None:
                    func(total_input, grad_output, weight.grad)
                # for first bwd; weight.grad is None, assign grad_weight to weight.grad
                else:
                    grad_weight = func(total_input, grad_output)
                    weight.grad = grad_weight
        else:
            raise Exception("Pop empty queue.")

    # @classmethod
    # def clear(cls, model, chunk=0):
    #     weight_grad_tasks = []
    #     while cls.weight_grad_queue[chunk].qsize() > 0:
    #         stored_grads = cls.weight_grad_queue[chunk].get()
    #         if len(weight_grad_tasks) == 0:
    #             for _ in stored_grads:
    #                 weight_grad_tasks.append([])
    #         else:
    #             assert len(weight_grad_tasks) == len(stored_grads)
    #         for i, task in enumerate(stored_grads):
    #             weight_grad_tasks[i].append(task)
    #     weight_params = []
    #     handles = []
    #     if get_args().overlap_grad_reduce:
    #         handles += model.async_reduce_grad()

    #     output_layer_weight = None
    #     if parallel_state.is_pipeline_last_stage():
    #         assert len(weight_grad_tasks) > 0
    #         output_layer_grads = weight_grad_tasks[0]
    #         for j in range(len(output_layer_grads)):
    #             total_input, grad_output, weight, func = output_layer_grads[j]
    #             if output_layer_weight is None:
    #                 output_layer_weight = weight
    #             assert output_layer_weight is weight
    #             func(total_input, grad_output, weight.main_grad)
    #             output_layer_grads[j] = None  # release memory
    #         weight_grad_tasks = weight_grad_tasks[1:]
    #         if get_args().overlap_grad_reduce:
    #             handles += model.async_reduce_grad(output_layer_weight)

    #     if parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage():
    #         model_module = get_attr_wrapped_model(model, 'pre_process', return_model_obj=True)
    #         if model_module.share_embeddings_and_output_weights:
    #             # if share_embeddings_and_output_weights, wait all-reduce for embeddings
    #             for handle in handles:
    #                 if handle is not None:
    #                     handle.wait()
    #             handles = []

    #     config = get_model_config(model)
    #     # Do async all-reduce for embedding grads firstly, so that the rank 0 won't
    #     # be blocked
    #     embedding_handles = _allreduce_embedding_grads([model], config, async_op=True)
    #     handles += embedding_handles

    #     for i in range(len(weight_grad_tasks)):
    #         tasks = weight_grad_tasks[i]
    #         param = None
    #         for j in range(len(tasks)):
    #             total_input, grad_output, weight, func = tasks[j]
    #             if param is None:
    #                 param = weight
    #             assert param is weight
    #             assert not (weight is output_layer_weight)
    #             func(total_input, grad_output, weight.main_grad)
    #             tasks[j] = None  # release memory
    #         weight_params.append(param)
    #         if get_args().overlap_grad_reduce:
    #             # All-reduce param grad here
    #             handles += model.async_reduce_grad(param)
    #         weight_grad_tasks[i] = None  # release memory

    #     # timers('wait_all_reduce', log_level=1).start(barrier=False)
    #     for handle in embedding_handles:
    #         if handle is not None:
    #             handle.wait()
    #     # timers('wait_all_reduce').stop()
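In isolation the store behaves like this: put() records one deferred GEMM, flush() seals everything recorded since the last flush into one group, and pop() executes one group, creating weight.grad on first use. A small usage sketch run outside any pipeline (toy shapes; the lambda stands in for the functools.partial closures the real backward registers, and only exercises the weight.grad-is-None path):

import torch
import torch.nn as nn

from colossalai.pipeline.weight_grad_store import WeightGradStore

layer = nn.Linear(4, 4, bias=False)
x = torch.randn(2, 4)
dy = torch.ones(2, 4)

# During backward-B we record the W pass instead of running it:
WeightGradStore.put(x, dy, layer.weight, lambda inp, grad_out: grad_out.t().matmul(inp))
WeightGradStore.flush(chunk=0)

# weight.grad is None, so pop() assigns the freshly computed gradient.
WeightGradStore.pop(chunk=0)
torch.testing.assert_close(layer.weight.grad, dy.t().matmul(x))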
colossalai/shardformer/layer/_operation.py:

@@ -1,7 +1,11 @@
+import functools
+
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F

+from colossalai.pipeline.weight_grad_store import WeightGradStore
+
 from .utils import is_share_sp_tp

 try:
@@ -125,12 +129,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=True):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_allreduce = async_grad_allreduce
         ctx.fp8_communication = fp8_communication
+        ctx.use_zbv = use_zbv
         if bias is not None:
             output = F.linear(input_, weight, bias)
         else:
@@ -143,6 +148,14 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         input, weight, bias = ctx.saved_tensors
         use_bias = ctx.use_bias
         fp8_communication = ctx.fp8_communication
+        use_zbv = ctx.use_zbv
+
+        def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
+            wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
+
+        def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
+            # _grad_output_.t().matmul(_input_)
+            return wgrad_gemm_func(_grad_output_.t(), _input_)

         # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
         if use_bias:
@@ -167,22 +180,60 @@ class LinearWithAsyncCommunication(torch.autograd.Function):

         if _grad_accum_fusion_available and weight.grad is not None:
             grad = weight.grad
-            if grad.dtype == torch.float32:
-                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
-                grad_weight = None
-            elif grad.dtype == torch.float16:
-                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
-                grad_weight = None
-            else:
-                grad_weight = grad_output.t().matmul(total_input)
-        else:
-            grad_weight = grad_output.t().matmul(total_input)
+            if use_zbv:
+                # TODO: append input, grad_output_, weight, grad func to WeightGradStore
+                if grad.dtype == torch.float32:
+                    WeightGradStore.put(
+                        total_input,
+                        grad_output,
+                        weight,
+                        functools.partial(
+                            execute_w_pass_grad_accum,
+                            wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
+                        ),
+                    )
+                    grad_weight = None
+                elif grad.dtype in (torch.float16, torch.bfloat16):
+                    WeightGradStore.put(
+                        total_input,
+                        grad_output,
+                        weight,
+                        functools.partial(
+                            execute_w_pass_grad_accum,
+                            wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
+                        ),
+                    )
+                    grad_weight = None
+                else:
+                    raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
+            else:
+                if grad.dtype == torch.float32:
+                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
+                    grad_weight = None
+                elif grad.dtype == torch.float16:
+                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
+                    grad_weight = None
+                else:
+                    grad_weight = grad_output.t().matmul(total_input)
+        else:
+            if use_zbv:
+                WeightGradStore.put(
+                    total_input,
+                    grad_output,
+                    weight,
+                    functools.partial(
+                        execute_w_pass,
+                        wgrad_gemm_func=torch.matmul,
+                    ),
+                )
+                grad_weight = None
+            else:
+                grad_weight = grad_output.t().matmul(total_input)

         grad_bias = grad_output.sum(dim=0) if use_bias else None

         if ctx.async_grad_allreduce and not fp8_communication:
             handle.wait()

-        return grad_input, grad_weight, grad_bias, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None
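With use_zbv the backward above returns grad_weight = None immediately; the actual GEMM runs whenever the scheduler drains the store. The deferral is just a closure over the GEMM arguments via functools.partial; a self-contained sketch (execute_w_pass copied from the helper above):

import functools
import torch

def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
    return wgrad_gemm_func(_grad_output_.t(), _input_)

total_input = torch.randn(8, 16)
grad_output = torch.randn(8, 32)

# What the zbv branch stores: a ready-to-run GEMM closure plus its inputs.
deferred = functools.partial(execute_w_pass, wgrad_gemm_func=torch.matmul)

# ... later, when schedule_w pops the store, the GEMM finally runs:
grad_weight = deferred(total_input, grad_output)
torch.testing.assert_close(grad_weight, grad_output.t().matmul(total_input))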
colossalai/shardformer/layer/linear.py:

@@ -201,7 +201,6 @@ class Linear1D_Col(ParallelModule):

         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None

         if self.seq_parallel_mode == "split_gather":
             input_parallel = gather_forward_reducescatter_backward(
                 input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
examples/language/llama/benchmark.py:

@@ -5,6 +5,8 @@ import warnings
 from contextlib import nullcontext

 import torch
+
+torch.autograd.set_detect_anomaly(True)
 import torch.distributed as dist
 from data_utils import RandomDataset
 from model_utils import format_numel_str, get_model_numel
@@ -251,6 +253,7 @@ def main():
             use_fp8=args.use_fp8,
             fp8_communication=args.use_fp8_comm,
             scheduler_nodes=scheduler_nodes,
+            make_vocab_size_divisible_by=1,
             **hybrid_kwargs,
         )
     elif args.plugin == "3d_cpu":
tests/test_pipeline/test_schedule/test_zerobubble_pp.py:

@@ -926,5 +926,6 @@ def test_pp():
     )


+# python -m pytest -s tests/test_pipeline/test_schedule/test_zerobubble_pp.py
 if __name__ == "__main__":
     test_pp()
New test file (path not shown in this view; dx/dw split experiments):

@@ -0,0 +1,628 @@
import gc
import time
from copy import deepcopy

import torch
import torch.nn as nn
from torch.testing import assert_close


def get_model_numel(model):
    return sum(p.numel() for p in model.parameters()) / 1024**2
# Step1: dx = w*dy
def backward_b(loss, x, model):
    torch.autograd.backward(loss, inputs=x, retain_graph=True)


# Step2: dummy dw = x*dy
def backward_w(loss, model):
    torch.autograd.backward(loss, inputs=list(model.parameters()))
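The split works because torch.autograd.backward accepts an inputs= filter: gradients are accumulated only into the listed tensors, and retain_graph=True keeps the graph alive for the second pass. A quick standalone check of that property (toy sizes, CPU is fine):

import torch
import torch.nn as nn

m = nn.Linear(4, 4, bias=None)
x = torch.randn(2, 4, requires_grad=True)
loss = m(x).sum()

torch.autograd.backward(loss, inputs=[x], retain_graph=True)  # B pass: only dx
assert x.grad is not None and m.weight.grad is None

torch.autograd.backward(loss, inputs=list(m.parameters()))    # W pass: only dw
assert m.weight.grad is not None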
def test_double_dx_dw_split_nsync():
    device = "cuda:0"
    model = nn.Linear(4096, 4096, bias=None).to(device=device)
    # print(f"model numel {get_model_numel(model)}")
    x1 = torch.rand(4096, 4096).to(device=device)
    x2 = torch.rand(4096, 4096).to(device=device)
    ref_model = deepcopy(model)
    ref_x1 = x1.clone()
    ref_x2 = x2.clone()

    # first step
    x1.requires_grad_()
    x2.requires_grad_()
    ref_x1.requires_grad_()
    ref_x2.requires_grad_()

    # loss for dx_dw bwd
    loss1 = model(x1).sum()
    loss2 = model(x2).sum()

    # loss for common bwd
    ref_loss1 = ref_model(ref_x1).sum()
    ref_loss2 = ref_model(ref_x2).sum()

    # dx1
    torch.cuda.synchronize()
    bwd_b_start_time = time.time()
    backward_b(loss1, x1, model)
    bwd_b_end_time = time.time()
    print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")

    for p in model.parameters():
        assert p.grad is None
    assert x1.grad is not None

    # dx2
    torch.cuda.synchronize()
    bwd_b_start_time = time.time()
    backward_b(loss2, x2, model)
    bwd_b_end_time = time.time()
    print(f"loss_2 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")

    # dw1
    torch.cuda.synchronize()
    bwd_w_start_time = time.time()
    backward_w(loss1, model)
    bwd_w_end_time = time.time()
    print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")
    for p in model.parameters():
        assert p.grad is not None

    # common bwd 1
    torch.cuda.synchronize()
    comm_bwd_start_time = time.time()
    ref_loss1.backward()
    comm_bwd_end_time = time.time()
    print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")

    # # assert dx1 & dw1 == bwd 1
    # assert_close(x1.grad, ref_x1.grad)
    # for p1, p2 in zip(model.parameters(), ref_model.parameters()):
    #     assert_close(p1, p2)
    #     assert_close(p1.grad, p2.grad)

    # dw2
    torch.cuda.synchronize()
    bwd_w_start_time = time.time()
    backward_w(loss2, model)
    bwd_w_end_time = time.time()
    print(f"loss_2 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")

    # common bwd 2
    torch.cuda.synchronize()
    comm_bwd_start_time = time.time()
    ref_loss2.backward()
    comm_bwd_end_time = time.time()
    print(f"loss_2 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")

    # # assert dx2 & dw2 == bwd 2
    # assert_close(x2.grad, ref_x2.grad)
    # for p1, p2 in zip(model.parameters(), ref_model.parameters()):
    #     print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
    #     assert_close(p1, p2)
    #     assert_close(p1.grad, p2.grad)
def test_double_dx_dw_split_sync():
    device = "cuda:0"
    model = nn.Linear(8, 8, bias=None).to(device=device)
    print(f"model size {get_model_numel(model)} ")
    x1 = torch.rand(8, 8).to(device=device)
    x2 = torch.rand(8, 8).to(device=device)

    # x1 = torch.ones(8, 8).to(device=device)
    # x2 = torch.ones(8, 8).to(device=device)

    ref_model = deepcopy(model)
    ref_x1 = x1.clone()
    ref_x2 = x2.clone()

    x1.requires_grad_()
    x2.requires_grad_()
    ref_x1.requires_grad_()
    ref_x2.requires_grad_()

    ############
    # step1:
    ############

    # loss1
    loss1 = model(x1).sum()

    # ref_loss1
    ref_loss1 = ref_model(ref_x1).sum()

    # dx1
    backward_b(loss1, x1, model)
    for p in model.parameters():
        assert p.grad is None
    assert x1.grad is not None

    # dw1
    backward_w(loss1, model)
    for p in model.parameters():
        assert p.grad is not None

    # common bwd 1
    ref_loss1.backward()

    # assert dx1 & dw1 == bwd 1
    assert_close(x1.grad, ref_x1.grad)
    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)

    ############
    # step2:
    ############

    # loss2
    loss2 = model(x2).sum()

    # ref_loss2
    ref_loss2 = ref_model(ref_x2).sum()

    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)

    # dx2
    backward_b(loss2, x2, model)

    # dw2
    backward_w(loss2, model)

    # common bwd 2
    ref_loss2.backward()

    # assert dx2 & dw2 == bwd 2
    assert_close(x2.grad, ref_x2.grad)
    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
        print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
        assert_close(p1, p2)
        assert_close(p1.grad, p2.grad)
def deallocate_output_tensor(out):
    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.

    This method should be called right after the output tensor has been
    sent to the next pipeline stage. At this point, the output tensor is
    only useful for its '.grad_fn' field, and not its '.data'.
    """
    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
    assert out._base is None, "counter-productive to free a view of another tensor."
    out.data = torch.empty(
        (1,),
        device=out.device,
        dtype=out.dtype,
    )


IN_DIM = 8192
OUT_DIM = 8192
NUM_LAYER = 3
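The helper above mirrors Megatron's pseudo-deallocation trick: once every consumer of an output has been built (or the tensor has been sent downstream), its .data can be swapped for a 1-element placeholder and backward still succeeds, because autograd only follows .grad_fn. A minimal check (toy shapes; note the loss is created before the swap):

import torch
import torch.nn as nn

m = nn.Linear(8, 8, bias=None)
x = torch.randn(8, 8, requires_grad=True)
y = m(x)
loss = y.sum()  # consumer built while y still holds real data

y.data = torch.empty((1,), device=y.device, dtype=y.dtype)  # pseudo-free y

loss.backward()  # works: the graph hangs off y.grad_fn, not y.data
assert x.grad is not None and m.weight.grad is not None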
class MlpModel(nn.Module):
    # defaults keep the original global sizes; model_chunk_dx_dw() below passes its own
    def __init__(self, in_dim=IN_DIM, out_dim=OUT_DIM, num_layers=NUM_LAYER):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, with_qkv=True):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.with_qkv = with_qkv
        if self.with_qkv:
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
            self.proj = nn.Linear(dim, dim)
            self.proj_drop = nn.Dropout(proj_drop)
        self.attn_drop = nn.Dropout(attn_drop)

    def forward(self, x):
        B, N, C = x.shape
        if self.with_qkv:
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]
        else:
            qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
            q, k, v = qkv, qkv, qkv

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        if self.with_qkv:
            x = self.proj(x)
            x = self.proj_drop(x)
        return x
def mem_dx_dw():
    device = "cuda:0"
    # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    model = MlpModel().to(device=device)
    print(f"model numel {get_model_numel(model)}")
    print(f"After init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    print(f"Before init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
    x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
    x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device)

    x1.requires_grad_()
    x2.requires_grad_()
    x3.requires_grad_()
    print(f"After init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ############
    # step1:
    ############
    print("\nStep1")

    # loss1
    print(f"Before Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    y1 = model(x1)
    print(f"After Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    print(f"Before loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    loss1 = y1.sum()
    print(f"After loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    # dx1
    backward_b(loss1, x1, model)

    # dw1
    backward_w(loss1, model)

    deallocate_output_tensor(x1)
    deallocate_output_tensor(y1)
    # del x1
    # del y1
    print(f"After del x1&y1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    # print(f"\n Step1:collect:{gc.collect()}")
    # print(f"object: {gc.get_objects()}")
    # print(f"garbage: {gc.garbage}")

    ############
    # step2:
    ############
    print("\nStep2")

    # loss2
    y2 = model(x2)
    loss2 = y2.sum()

    # dx2
    backward_b(loss2, x2, model)

    # dw2
    backward_w(loss2, model)
    deallocate_output_tensor(x2)
    deallocate_output_tensor(y2)
    # del x2
    # del y2
    print(f"After del x2&y2: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    print(f"\n Step2:collect:{gc.collect()}")
    # print(f"object: {gc.get_objects()}")
    print(f"garbage: {gc.garbage}")

    ############
    # step3:
    ############

    print("\nStep3")

    # loss3
    y3 = model(x3)
    loss3 = y3.sum()

    # dx3
    backward_b(loss3, x3, model)

    # dw3
    backward_w(loss3, model)

    deallocate_output_tensor(x3)
    deallocate_output_tensor(y3)
    # del x3
    # del y3

    print(f"After del x3&y3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    print(f"\n Step3:collect:{gc.collect()}")
    # print(f"object: {gc.get_objects()}")
    print(f"garbage: {gc.garbage}")
# del activation
def activation_dx_dw():
    device = "cuda:0"
    # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    model = MlpModel().to(device=device)
    x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
    x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
    x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device)

    x1.requires_grad_()
    x2.requires_grad_()
    x3.requires_grad_()
    print(f"After init Model, x1,x2,x3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    activations = {}

    def register_hooks(module):
        def activation_hook(module, input, output):
            activations[f"{module.__class__.__name__}_{id(module)}"] = output.detach()

        def bwd_hook(module, grad_input, grad_output):
            del activations[f"{module.__class__.__name__}_{id(module)}"]

        module.register_forward_hook(activation_hook)
        module.register_backward_hook(bwd_hook)

    model.apply(register_hooks)

    ############
    # step1:
    ############
    print("\nStep1")

    # loss1
    loss1 = model(x1).sum()

    # dx1
    backward_b(loss1, x1, model)

    # dw1
    backward_w(loss1, model)

    del loss1, x1
    print(f"After del x1&y1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ############
    # step2:
    ############
    print("\nStep2")

    # loss2
    loss2 = model(x2).sum()

    # dx2
    backward_b(loss2, x2, model)

    # dw2
    backward_w(loss2, model)

    # deallocate_output_tensor(x2)
    # deallocate_output_tensor(loss2)
    del x2, loss2
    print(f"After del x2&y2: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    ############
    # step3:
    ############
    print("\nStep3")

    # loss3
    loss3 = model(x3).sum()

    # dx3
    backward_b(loss3, x3, model)

    # dw3
    backward_w(loss3, model)

    del x3, loss3

    print(f"After del x3&y3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
# test dx dw in model chunk
def model_chunk_dx_dw():
    device = "cuda:0"
    num_layers = 4
    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
    model = MlpModel(in_dim=4096, out_dim=4096, num_layers=num_layers).to(device=device)
    x = torch.rand(4096, 4096).to(device=device)
    x.requires_grad_()

    model_chunk_0 = torch.nn.ModuleList()  # for layer 1 & 2
    model_chunk_1 = torch.nn.ModuleList()  # for layer 3 & 4

    for idx, sub_model in enumerate(model.layers):
        if idx < 2:
            model_chunk_0.append(sub_model)
        else:
            model_chunk_1.append(sub_model)

    print(f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    # Step1: chunk 0 fwd
    activation = dict()  # layer_id: activation
    out = x
    for i in range(len(model_chunk_0)):
        layer = model_chunk_0[i]
        out = layer(out)
        activation[i] = out
    print(f"After chunk0 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    # Step2: chunk 1 fwd
    for i in range(len(model_chunk_1)):
        layer = model_chunk_1[i]
        out = layer(out)
        activation[i + 2] = out
    print(f"After chunk1 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")

    # Step3: chunk 1 bwd b: dx=w*dy & bwd w: dw=x*dy
    # visit layers in reverse; the last layer in the last chunk anchors the loss
    for i in range(len(model_chunk_1) - 1, -1, -1):
        layer = model_chunk_1[i]
        global_layer_idx = i + 2
        prev_global_layer_idx = global_layer_idx - 1

        # bwd b
        loss = activation[global_layer_idx].sum()
        x_in = activation[prev_global_layer_idx]
        backward_b(loss, x_in, layer)

        # bwd w
        backward_w(loss, layer)
def test_dx_dw_linear_benchmark():
    device = "cuda:0"
    model = nn.Linear(4096, 4096, bias=None).to(device=device)
    # print(f"model numel {get_model_numel(model)}")
    x1 = torch.rand(4096, 4096).to(device=device)
    # x2 = torch.rand(4096, 4096).to(device=device)
    ref_model = deepcopy(model)
    ref_x1 = x1.clone()
    # ref_x2 = x1.clone()

    # first step
    x1.requires_grad_()
    # x2.requires_grad_()
    ref_x1.requires_grad_()
    # ref_x2.requires_grad_()

    # loss for dx_dw bwd
    loss1 = model(x1).sum()
    # loss2 = model(x2).sum()

    # loss for common bwd
    ref_model(ref_x1).sum()
    # ref_loss2 = ref_model(ref_x2).sum()

    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            "/home/nvme-share/home/duanjunwen/ColossalAI/tests/test_pipeline/test_schedule"
        ),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        with_flops=True,
    ) as prof:
        # dx1
        torch.cuda.synchronize()
        bwd_b_start_time = time.time()
        backward_b(loss1, x1, model)
        bwd_b_end_time = time.time()
        print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")

        for p in model.parameters():
            assert p.grad is None
        assert x1.grad is not None

        # dw1
        torch.cuda.synchronize()
        bwd_w_start_time = time.time()
        backward_w(loss1, model)
        bwd_w_end_time = time.time()
        print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")
        for p in model.parameters():
            assert p.grad is not None

        # # common bwd 1
        # torch.cuda.synchronize()
        # comm_bwd_start_time = time.time()
        # ref_loss1.backward()
        # comm_bwd_end_time = time.time()
        # print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")
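Both benchmarks write TensorBoard traces via tensorboard_trace_handler; for a quick look without TensorBoard, the standard torch.profiler summary table works too (a small sketch, assuming a CUDA device is available):

import torch

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
) as prof:
    a = torch.randn(1024, 1024, device="cuda")
    b = a @ a
    torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))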
def test_dx_dw_attn_benchmark():
    device = "cuda:0"
    model = Attention(dim=4096).to(device=device)
    # print(f"model numel {get_model_numel(model)}")
    x1 = torch.rand(1, 256, 4096).to(device=device)
    # x2 = torch.rand(1, 256, 4096).to(device=device)
    ref_model = deepcopy(model)
    ref_x1 = x1.clone()
    # ref_x2 = x1.clone()

    # first step
    x1.requires_grad_()
    # x2.requires_grad_()
    ref_x1.requires_grad_()
    # ref_x2.requires_grad_()

    # loss for dx_dw bwd
    loss1 = model(x1).sum()
    # loss2 = model(x2).sum()

    # loss for common bwd
    ref_model(ref_x1).sum()
    # ref_loss2 = ref_model(ref_x2).sum()

    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            "/home/nvme-share/home/duanjunwen/ColossalAI/tests/test_pipeline/test_schedule"
        ),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        with_flops=True,
    ) as prof:
        # dx1
        torch.cuda.synchronize()
        bwd_b_start_time = time.time()
        backward_b(loss1, x1, model)
        bwd_b_end_time = time.time()
        print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")

        for p in model.parameters():
            assert p.grad is None
        assert x1.grad is not None

        # dw1
        torch.cuda.synchronize()
        bwd_w_start_time = time.time()
        backward_w(loss1, model)
        bwd_w_end_time = time.time()
        print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")
        for p in model.parameters():
            assert p.grad is not None

        # # common bwd 1
        # torch.cuda.synchronize()
        # comm_bwd_start_time = time.time()
        # ref_loss1.backward()
        # comm_bwd_end_time = time.time()
        # print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")


if __name__ == "__main__":
    # test_dx_dw_split()
    # test_double_dx_dw_split_nsync()
    # test_double_dx_dw_split_sync()
    # mem_dx_dw()
    # activation_dx_dw()
    # test_dx_dw_linear_benchmark()
    test_dx_dw_attn_benchmark()