Browse Source

[fix] rm use_zbv flag in Shardconfig; rm debug info;

pull/6083/head
duanjunwen 1 month ago
parent
commit
e76308c6e6
  1. 1
      colossalai/booster/plugin/hybrid_parallel_plugin.py
  2. 1
      colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
  3. 24
      colossalai/shardformer/policies/llama.py
  4. 28
      colossalai/shardformer/policies/mixtral.py
  5. 1
      colossalai/shardformer/shard/shard_config.py
  6. 2
      examples/language/llama/benchmark.py
  7. 176
      tests/test_pipeline/test_schedule/test_zerobubble_pp.py
  8. 628
      tests/test_pipeline/test_schedule/zbv_poc.py
  9. 2
      tests/test_shardformer/test_model/test_shard_llama.py

1
colossalai/booster/plugin/hybrid_parallel_plugin.py

@ -1201,7 +1201,6 @@ class HybridParallelPlugin(PipelinePluginBase):
gradient_checkpoint_config=gradient_checkpoint_config,
fp8_communication=fp8_communication,
inner_ring_size=inner_ring_size,
use_zbv=(pp_style == "zbv"),
)
self.amp_config = dict(
initial_scale=initial_scale,

1
colossalai/booster/plugin/moe_hybrid_parallel_plugin.py

@ -373,7 +373,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
fp8_communication=fp8_communication,
use_zbv=(pp_style == "zbv"),
)
self.amp_config = dict(
initial_scale=initial_scale,

24
colossalai/shardformer/policies/llama.py

@ -60,6 +60,11 @@ class LlamaPolicy(Policy):
else:
norm_cls = RMSNorm
if self.pipeline_stage_manager:
use_zbv = self.pipeline_stage_manager.use_zbv
else:
use_zbv = False
sp_mode = self.shard_config.sequence_parallelism_mode or None
sp_size = self.shard_config.sequence_parallel_size or None
sp_group = self.shard_config.sequence_parallel_process_group or None
@ -129,7 +134,7 @@ class LlamaPolicy(Policy):
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=self.shard_config.use_zbv,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
@ -138,7 +143,7 @@ class LlamaPolicy(Policy):
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=self.shard_config.use_zbv,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
@ -147,7 +152,7 @@ class LlamaPolicy(Policy):
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=self.shard_config.use_zbv,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
@ -156,7 +161,7 @@ class LlamaPolicy(Policy):
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=self.shard_config.use_zbv,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
@ -165,7 +170,7 @@ class LlamaPolicy(Policy):
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=self.shard_config.use_zbv,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
@ -174,7 +179,7 @@ class LlamaPolicy(Policy):
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=self.shard_config.use_zbv,
use_zbv=use_zbv,
),
),
SubModuleReplacementDescription(
@ -183,7 +188,7 @@ class LlamaPolicy(Policy):
kwargs=dict(
seq_parallel_mode=sp_mode,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=self.shard_config.use_zbv,
use_zbv=use_zbv,
),
),
],
@ -413,6 +418,10 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
from transformers import LlamaForSequenceClassification
policy = super().module_policy()
if self.pipeline_stage_manager:
use_zbv = self.pipeline_stage_manager.use_zbv
else:
use_zbv = False
if self.shard_config.enable_tensor_parallelism:
# add a new item for sequence classification
@ -425,6 +434,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
)
]

28
colossalai/shardformer/policies/mixtral.py

@ -52,6 +52,10 @@ class MixtralPolicy(Policy):
sp_group = self.shard_config.sequence_parallel_process_group or None
sp_partial_derived = sp_mode in ["split_gather", "ring"]
tp_size = self.shard_config.tensor_parallel_size
if self.pipeline_stage_manager:
use_zbv = self.pipeline_stage_manager.use_zbv
else:
use_zbv = False
# modified for both SP and TP
num_q_heads = self.model.config.num_attention_heads
@ -126,7 +130,7 @@ class MixtralPolicy(Policy):
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": self.shard_config.use_zbv,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -134,7 +138,7 @@ class MixtralPolicy(Policy):
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": self.shard_config.use_zbv,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -142,7 +146,7 @@ class MixtralPolicy(Policy):
target_module=Linear1D_Col,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": self.shard_config.use_zbv,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -150,7 +154,7 @@ class MixtralPolicy(Policy):
target_module=Linear1D_Row,
kwargs={
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": self.shard_config.use_zbv,
"use_zbv": use_zbv,
},
),
SubModuleReplacementDescription(
@ -159,7 +163,7 @@ class MixtralPolicy(Policy):
kwargs={
"gather_output": True,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": self.shard_config.use_zbv,
"use_zbv": use_zbv,
},
),
],
@ -195,7 +199,7 @@ class MixtralPolicy(Policy):
"tp_group": self.shard_config.tensor_parallel_process_group,
"moe_dp_group": self.shard_config.moe_dp_group,
"fp8_communication": self.shard_config.fp8_communication,
"use_zbv": self.shard_config.use_zbv,
"use_zbv": use_zbv,
},
)
],
@ -330,6 +334,10 @@ class MixtralModelPolicy(MixtralPolicy):
class MixtralForCausalLMPolicy(MixtralPolicy):
def module_policy(self):
policy = super().module_policy()
if self.pipeline_stage_manager:
use_zbv = self.pipeline_stage_manager.use_zbv
else:
use_zbv = False
# TODO: assign pg mesh from plugin to all modules
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
@ -342,7 +350,7 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=self.shard_config.use_zbv,
use_zbv=use_zbv,
),
)
],
@ -392,6 +400,10 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
from transformers import MixtralForSequenceClassification
policy = super().module_policy()
if self.pipeline_stage_manager:
use_zbv = self.pipeline_stage_manager.use_zbv
else:
use_zbv = False
if self.shard_config.enable_tensor_parallelism:
# add a new item for sequence classification
@ -404,7 +416,7 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
kwargs=dict(
gather_output=True,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=self.shard_config.use_zbv,
use_zbv=use_zbv,
),
)
]

1
colossalai/shardformer/shard/shard_config.py

@ -49,7 +49,6 @@ class ShardConfig:
make_vocab_size_divisible_by: int = 64
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
use_zbv: bool = False
# For ring attention
inner_ring_size: Optional[int] = None

2
examples/language/llama/benchmark.py

@ -5,8 +5,6 @@ import warnings
from contextlib import nullcontext
import torch
torch.autograd.set_detect_anomaly(True)
import torch.distributed as dist
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel

176
tests/test_pipeline/test_schedule/test_zerobubble_pp.py

@ -8,12 +8,14 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralModel
import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin, MoeHybridParallelPlugin
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import OptimizerWrapper
from colossalai.logging import disable_existing_loggers
@ -918,11 +920,181 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
torch.cuda.empty_cache()
@parameterize(
"config",
[
(0, 4, 1, 1),
# (1, 2, 2, 1),
# (1, 2, 1, 2),
# (1, 1, 2, 2),
],
)
def run_with_booster_hybridplugin(config: Tuple[int, ...]):
stage, pp_size, tp_size, sp_size = config
num_microbatches = pp_size
dist.get_world_size()
rank = dist.get_rank()
dtype, precision = torch.float16, "fp16"
torch.cuda.set_device(dist.get_rank())
########
# init base model
########
assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
config = LlamaConfig(
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
num_hidden_layers=NUM_LAYERS,
num_attention_heads=NUM_HEADS,
num_key_value_heads=NUM_HEADS,
attn_implementation="flash_attention_2",
)
# init model with the same seed
seed_all(10086)
torch_model = LlamaModel(config).to(dtype).cuda()
# TODO: Support MixtralForCausalLM
# torch_model = MixtralForCausalLM(config).to(dtype).cuda()
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
# init schedule
h, a, s = config.hidden_size, config.num_attention_heads, 1024
mem_f = 34 * h + 5 * a * s
mem_w = -32 * h
mem_b = -mem_w - mem_f
graph = PipelineGraph(
n_stage=pp_size,
n_micro=num_microbatches,
f_cost=1,
b_cost=1,
w_cost=1,
c_cost=1,
f_mem=mem_f,
b_mem=mem_b,
w_mem=mem_w,
)
zbv_schedule = graph.get_v_schedule()
# init MoeHybridPlugin
plugin = HybridParallelPlugin(
pp_size=pp_size,
num_microbatches=pp_size,
tp_size=tp_size,
sp_size=sp_size,
zero_stage=stage,
enable_sequence_parallelism=sp_size > 1,
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
overlap_communication=False,
initial_scale=1,
precision=precision,
find_unused_parameters=True,
pp_style="zbv",
scheduler_nodes=zbv_schedule,
num_model_chunks=2,
)
dp_size = plugin.dp_size
booster = Booster(plugin=plugin)
########
# init pp model
########
parallel_model = deepcopy(torch_model)
parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)
# create different input along dp axis
seed_all(1453 + rank)
torch_model.train()
parallel_model.train()
for i in range(2):
# gen random input
# input = torch.rand(
# NUM_BATCH, NUM_TOK_PER_BATCH, NUM_HEADS, HIDDEN_SIZE_PER_HEAD, requires_grad=True
# ).cuda()
input_ids = torch.randint(0, torch_model.vocab_size, (NUM_BATCH, config.max_position_embeddings)).cuda()
attention_mask = torch.ones_like(input_ids).cuda()
input_ids.clone().cuda()
input_data = {"input_ids": input_ids, "attention_mask": attention_mask}
# dist.all_reduce(
# input, group=plugin.pp_group
# ) # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check
# dist.all_reduce(input, group=plugin.tp_group) # tp group duplicate input
# dist.all_reduce(input, group=plugin.sp_group) # sp group duplicate input
# run the model with hybrid parallel
if booster.plugin.stage_manager is not None:
# for test with pp
data_iter = iter([input_data])
sharded_output = booster.execute_pipeline(
data_iter,
parallel_model,
lambda x, y: x.last_hidden_state.mean(),
parallel_optimizer,
return_loss=True,
return_outputs=True,
)
# stage 0 chunk 0
parallel_output = None
if (
booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
and rank == dist.get_process_group_ranks(plugin.pp_group)[0]
):
parallel_output = sharded_output["loss"]
else:
parallel_output = torch.tensor(12345.0, device="cuda")
# broadcast along pp axis
dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group)
else:
# for test without pp
parallel_output = parallel_model(
input_ids=input_data["input_ids"],
attention_mask=input_data["attention_mask"],
).last_hidden_state.mean()
parallel_optimizer.backward(parallel_output)
parallel_optimizer.step()
parallel_optimizer.zero_grad()
dist.all_reduce(parallel_output, group=plugin.dp_group)
# ===================================================================================
# run normal model with all dp(different) inputs
all_inputs = [input_data for _ in range(dp_size)]
# dist.all_gather(all_inputs, input, group=plugin.dp_group)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(
input_ids=input_data_["input_ids"],
attention_mask=input_data_["attention_mask"],
).last_hidden_state.mean()
torch_output.backward()
torch_output_sum += torch_output.detach()
# print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
# avg dp grads follows zero optimizer
for p in torch_model.parameters():
if p.grad is not None:
p.grad /= dp_size
torch_optimizer.step()
torch_optimizer.zero_grad()
print(f"loop {i} rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
# assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
# print(f"rank {dist.get_rank()} config {test_config} test passed")
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
def run_dist(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
# run_fwd_bwd_vschedule_with_optim()
run_with_booster_moehybridplugin()
# run_with_booster_hybridplugin()
@pytest.mark.dist

628
tests/test_pipeline/test_schedule/zbv_poc.py

@ -1,628 +0,0 @@
import gc
import time
from copy import deepcopy
import torch
import torch.nn as nn
from torch.testing import assert_close
def get_model_numel(model):
return sum(p.numel() for p in model.parameters()) / 1024**2
# Step1: dx = w*dy
def backward_b(loss, x, model):
torch.autograd.backward(loss, inputs=x, retain_graph=True)
# Step2: dummy dw = x*dy
def backward_w(loss, model):
torch.autograd.backward(loss, inputs=list(model.parameters()))
def test_double_dx_dw_split_nsync():
device = "cuda:0"
model = nn.Linear(4096, 4096, bias=None).to(device=device)
# print(f"model numel {get_model_numel(model)}") # 4GB
x1 = torch.rand(4096, 4096).to(device=device)
x2 = torch.rand(4096, 4096).to(device=device)
ref_model = deepcopy(model)
ref_x1 = x1.clone()
ref_x2 = x1.clone()
# first step
x1.requires_grad_()
x2.requires_grad_()
ref_x1.requires_grad_()
ref_x2.requires_grad_()
# loss for dx_dw bwd
loss1 = model(x1).sum()
loss2 = model(x2).sum()
# loss for common bwd
ref_loss1 = ref_model(ref_x1).sum()
ref_loss2 = ref_model(ref_x2).sum()
# dx1
torch.cuda.synchronize()
bwd_b_start_time = time.time()
backward_b(loss1, x1, model)
bwd_b_end_time = time.time()
print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")
for p in model.parameters():
assert p.grad is None
assert x1.grad is not None
# dx2
torch.cuda.synchronize()
bwd_b_start_time = time.time()
backward_b(loss2, x2, model)
bwd_b_end_time = time.time()
print(f"loss_2 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")
# dw1
torch.cuda.synchronize()
bwd_w_start_time = time.time()
backward_w(loss1, model)
bwd_w_end_time = time.time()
print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")
for p in model.parameters():
assert p.grad is not None
# common bwd 1
torch.cuda.synchronize()
comm_bwd_start_time = time.time()
ref_loss1.backward()
comm_bwd_end_time = time.time()
print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")
# # assert dx1 & dw1 == bwd 1
# assert_close(x1.grad, ref_x1.grad)
# for p1, p2 in zip(model.parameters(), ref_model.parameters()):
# assert_close(p1, p2)
# assert_close(p1.grad, p2.grad)
# dw2
torch.cuda.synchronize()
bwd_w_start_time = time.time()
backward_w(loss2, model)
bwd_w_end_time = time.time()
print(f"loss_2 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")
# common bwd 2
torch.cuda.synchronize()
comm_bwd_start_time = time.time()
ref_loss2.backward()
comm_bwd_end_time = time.time()
print(f"loss_2 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")
# # assert dx2 & dw2 == bwd 2
# assert_close(x2.grad, ref_x2.grad)
# for p1, p2 in zip(model.parameters(), ref_model.parameters()):
# print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
# assert_close(p1, p2)
# assert_close(p1.grad, p2.grad)
def test_double_dx_dw_split_sync():
device = "cuda:0"
model = nn.Linear(8, 8, bias=None).to(device=device)
print(f"model size {get_model_numel(model)} ") # 4GB
x1 = torch.rand(8, 8).to(device=device)
x2 = torch.rand(8, 8).to(device=device)
# x1 = torch.ones(8, 8).to(device=device)
# x2 = torch.ones(8, 8).to(device=device)
ref_model = deepcopy(model)
ref_x1 = x1.clone()
ref_x2 = x2.clone()
x1.requires_grad_()
x2.requires_grad_()
ref_x1.requires_grad_()
ref_x2.requires_grad_()
############
# step1:
############
# loss1
loss1 = model(x1).sum()
# ref_loss1
ref_model(ref_x1).sum()
# dx1
backward_b(loss1, x1, model)
for p in model.parameters():
assert p.grad is None
assert x1.grad is not None
# dw1
backward_w(loss1, model)
for p in model.parameters():
assert p.grad is not None
# common bwd 1
# ref_loss1.backward()
# assert dx1 & dw1 == bwd 1
assert_close(x1.grad, ref_x1.grad)
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
assert_close(p1, p2)
assert_close(p1.grad, p2.grad)
############
# step2:
############
# loss2
loss2 = model(x2).sum()
# ref_loss2
ref_loss2 = ref_model(ref_x2).sum()
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
assert_close(p1, p2)
assert_close(p1.grad, p2.grad)
# dx2
backward_b(loss2, x2, model)
# dw2
backward_w(loss2, model)
# common bwd 2
ref_loss2.backward()
# assert dx2 & dw2 == bwd 2
assert_close(x2.grad, ref_x2.grad)
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
assert_close(p1, p2)
assert_close(p1.grad, p2.grad)
def deallocate_output_tensor(out):
"""Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
This method should be called right after the output tensor has been
sent to the next pipeline stage. At this point, the output tensor is
only useful for its '.grad_fn' field, and not its '.data'.
"""
assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
assert out._base is None, "counter-productive to free a view of another tensor."
out.data = torch.empty(
(1,),
device=out.device,
dtype=out.dtype,
)
IN_DIM = 8192
OUT_DIM = 8192
NUM_LAYER = 3
class MlpModel(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.ModuleList([nn.Linear(IN_DIM, OUT_DIM, bias=None) for _ in range(NUM_LAYER)])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, with_qkv=True):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.with_qkv = with_qkv
if self.with_qkv:
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.attn_drop = nn.Dropout(attn_drop)
def forward(self, x):
B, N, C = x.shape
if self.with_qkv:
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
else:
qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
q, k, v = qkv, qkv, qkv
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
if self.with_qkv:
x = self.proj(x)
x = self.proj_drop(x)
return x
def mem_dx_dw():
device = "cuda:0"
# model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
model = MlpModel().to(device=device)
print(f"model numel {get_model_numel(model)}") # 4GB
print(f"After init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
print(f"Before init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
x1.requires_grad_()
x2.requires_grad_()
x3.requires_grad_()
print(f"After init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
############
# step1:
############
print(f"\nStep1")
# loss1
print(f"Before Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
y1 = model(x1)
print(f"After Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
print(f"Before loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
loss1 = y1.sum()
print(f"After loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
# dx1
backward_b(loss1, x1, model)
# dw1
backward_w(loss1, model)
deallocate_output_tensor(x1)
deallocate_output_tensor(y1)
# del x1
# del y1
print(f"After del x1&y1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
# print(f"\n Step1:collect:{gc.collect()}")
# print(f"object: {gc.get_objects()}")
# print(f"garbage: {gc.garbage}")
############
# step2:
############
print(f"\nStep2")
# loss2
y2 = model(x2)
loss2 = y2.sum()
# dx2
backward_b(loss2, x2, model)
# dw2
backward_w(loss2, model)
deallocate_output_tensor(x2)
deallocate_output_tensor(y2)
# del x2
# del y2
print(f"After del x2&y2: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
print(f"\n Step2:collect:{gc.collect()}")
# print(f"object: {gc.get_objects()}")
print(f"garbage: {gc.garbage}")
############
# step3:
############
print(f"\nStep3")
# loss3
y3 = model(x3)
loss3 = y3.sum()
# dx2
backward_b(loss3, x3, model)
# dw2
backward_w(loss3, model)
deallocate_output_tensor(x3)
deallocate_output_tensor(y3)
# del x3
# del y3
print(f"After del x3&y3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
print(f"\n Step3:collect:{gc.collect()}")
# print(f"object: {gc.get_objects()}")
print(f"garbage: {gc.garbage}")
# del activation
def activation_dx_dw():
device = "cuda:0"
# model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
model = MlpModel().to(device=device)
x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
x1.requires_grad_()
x2.requires_grad_()
x3.requires_grad_()
print(f"After init Model, x1,x2,x3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
activations = {}
def register_hooks(module):
def activation_hook(module, input, output):
activations[f"{module.__class__.__name__}_{id(module)}"] = output.detach()
def bwd_hook(module, grad_input, grad_output):
del activations[f"{module.__class__.__name__}_{id(module)}"]
module.register_forward_hook(activation_hook)
module.register_backward_hook(bwd_hook)
model.apply(register_hooks)
############
# step1:
############
print(f"\nStep1")
# loss1
loss1 = model(x1).sum()
# dx1
backward_b(loss1, x1, model)
# dw1
backward_w(loss1, model)
del loss1, x1
print(f"After del x1&y1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
############
# step2:
############
print(f"\nStep2")
# loss2
loss2 = model(x2).sum()
# dx2
backward_b(loss2, x2, model)
# dw2
backward_w(loss2, model)
# deallocate_output_tensor(x2)
# deallocate_output_tensor(loss2)
del x2, loss2
print(f"After del x2&y2: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
############
# step3:
############
print(f"\nStep3")
# loss3
loss3 = model(x3).sum()
# dx2
backward_b(loss3, x3, model)
# dw2
backward_w(loss3, model)
del x3, loss3
print(f"After del x3&y3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
# text dx dw in model chunk
def model_chunk_dx_dw():
device = "cuda:0"
num_layers = 4
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
model = MlpModel(in_dim=4096, out_dim=4096, num_layers=num_layers).to(device=device)
x = torch.rand(4096, 4096).to(device=device)
x.requires_grad_()
model_chunk_0 = torch.nn.ModuleList() # for layer 1 & 2
model_chunk_1 = torch.nn.ModuleList() # for layer 3 & 4
for idx, sub_model in enumerate(model.layers):
if idx < 2:
model_chunk_0.append(sub_model).cuda()
else:
model_chunk_1.append(sub_model).cuda()
print(f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
# Step1:chunk 0 fwd
activation = dict() # layer_id: activation
out = x
for i in range(len(model_chunk_0)):
layer = model_chunk_0[i]
activation[i] = layer(out)
print(f"After chunk0 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
# Step2:chunk 1 fwd
for i in range(len(model_chunk_1)):
layer = model_chunk_0[i]
activation[i + 2] = layer(out)
print(f"After chunk1 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
# Step3:chunk 1 bwd b: dx=w*dy & bwd w:dw=x*dy
# visit layer reversely
for i in range(len(model_chunk_1) - 1, -1, -1):
layer = model_chunk_1[i]
global_layer_idx = i + 2
prev_global_layer_idx = i + 1 if i + 1 > 0 else None
i + 3 if i + 3 < 4 else None
# bwd b
if global_layer_idx == num_layers - 1: # last layer in last chunk; calculate loss
loss = activation[global_layer_idx].sum()
x = activation[prev_global_layer_idx]
backward_b(loss, x, layer)
else:
loss = activation[global_layer_idx].sum()
x = activation[prev_global_layer_idx]
backward_b(loss, x, layer)
# bwd w
backward_w(loss, layer)
def test_dx_dw_linear_benchmark():
device = "cuda:0"
model = nn.Linear(4096, 4096, bias=None).to(device=device)
# print(f"model numel {get_model_numel(model)}") # 4GB
x1 = torch.rand(4096, 4096).to(device=device)
# x2 = torch.rand(4096, 4096).to(device=device)
ref_model = deepcopy(model)
ref_x1 = x1.clone()
# ref_x2 = x1.clone()
# first step
x1.requires_grad_()
# x2.requires_grad_()
ref_x1.requires_grad_()
# ref_x2.requires_grad_()
# loss for dx_dw bwd
loss1 = model(x1).sum()
# loss2 = model(x2).sum()
# loss for common bwd
ref_model(ref_x1).sum()
# ref_loss2 = ref_model(ref_x2).sum()
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
# schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(
f"/home/nvme-share/home/duanjunwen/ColossalAI/tests/test_pipeline/test_schedule"
),
record_shapes=True,
profile_memory=True,
with_stack=True,
with_flops=True,
) as prof:
# dx1
torch.cuda.synchronize()
bwd_b_start_time = time.time()
backward_b(loss1, x1, model)
bwd_b_end_time = time.time()
print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")
for p in model.parameters():
assert p.grad is None
assert x1.grad is not None
# dw1
torch.cuda.synchronize()
bwd_w_start_time = time.time()
backward_w(loss1, model)
bwd_w_end_time = time.time()
print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")
for p in model.parameters():
assert p.grad is not None
# # common bwd 1
# torch.cuda.synchronize()
# comm_bwd_start_time = time.time()
# ref_loss1.backward()
# comm_bwd_end_time = time.time()
# print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")
def test_dx_dw_attn_benchmark():
device = "cuda:0"
model = Attention(dim=4096).to(device=device)
# print(f"model numel {get_model_numel(model)}") # 4GB
x1 = torch.rand(1, 256, 4096).to(device=device)
# x2 = torch.rand(1, 256, 4096).to(device=device)
ref_model = deepcopy(model)
ref_x1 = x1.clone()
# ref_x2 = x1.clone()
# first step
x1.requires_grad_()
# x2.requires_grad_()
ref_x1.requires_grad_()
# ref_x2.requires_grad_()
# loss for dx_dw bwd
loss1 = model(x1).sum()
# loss2 = model(x2).sum()
# loss for common bwd
ref_model(ref_x1).sum()
# ref_loss2 = ref_model(ref_x2).sum()
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
# schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(
f"/home/nvme-share/home/duanjunwen/ColossalAI/tests/test_pipeline/test_schedule"
),
record_shapes=True,
profile_memory=True,
with_stack=True,
with_flops=True,
) as prof:
# dx1
torch.cuda.synchronize()
bwd_b_start_time = time.time()
backward_b(loss1, x1, model)
bwd_b_end_time = time.time()
print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")
for p in model.parameters():
assert p.grad is None
assert x1.grad is not None
# dw1
torch.cuda.synchronize()
bwd_w_start_time = time.time()
backward_w(loss1, model)
bwd_w_end_time = time.time()
print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")
for p in model.parameters():
assert p.grad is not None
# # common bwd 1
# torch.cuda.synchronize()
# comm_bwd_start_time = time.time()
# ref_loss1.backward()
# comm_bwd_end_time = time.time()
# print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")
if __name__ == "__main__":
# test_dx_dw_split()
# test_double_dx_dw_split_nsync()
# test_double_dx_dw_split_sync()
# mem_dx_dw()
# activation_dx_dw()
# test_dx_dw_linear_benchmark()
test_dx_dw_attn_benchmark()

2
tests/test_shardformer/test_model/test_shard_llama.py

@ -277,7 +277,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"precision": "fp16",
"initial_scale": 1,
},
# TODO: assert layer error
# # TODO: assert layer error
# {
# "tp_size": 2,
# "pp_size": 2,

Loading…
Cancel
Save