From e76308c6e65cb73cc5b20936bd232ba7390c6b11 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 16 Oct 2024 03:25:04 +0000 Subject: [PATCH] [fix] rm use_zbv flag in Shardconfig; rm debug info; --- .../booster/plugin/hybrid_parallel_plugin.py | 1 - .../plugin/moe_hybrid_parallel_plugin.py | 1 - colossalai/shardformer/policies/llama.py | 24 +- colossalai/shardformer/policies/mixtral.py | 28 +- colossalai/shardformer/shard/shard_config.py | 1 - examples/language/llama/benchmark.py | 2 - .../test_schedule/test_zerobubble_pp.py | 176 ++++- tests/test_pipeline/test_schedule/zbv_poc.py | 628 ------------------ .../test_model/test_shard_llama.py | 2 +- 9 files changed, 212 insertions(+), 651 deletions(-) delete mode 100644 tests/test_pipeline/test_schedule/zbv_poc.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index bba943f12..caeed5457 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1201,7 +1201,6 @@ class HybridParallelPlugin(PipelinePluginBase): gradient_checkpoint_config=gradient_checkpoint_config, fp8_communication=fp8_communication, inner_ring_size=inner_ring_size, - use_zbv=(pp_style == "zbv"), ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index b7e65c6a2..8b62a1e2b 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -373,7 +373,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): make_vocab_size_divisible_by=make_vocab_size_divisible_by, gradient_checkpoint_config=gradient_checkpoint_config, fp8_communication=fp8_communication, - use_zbv=(pp_style == "zbv"), ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 5c68d0c5e..db4515d7e 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -60,6 +60,11 @@ class LlamaPolicy(Policy): else: norm_cls = RMSNorm + if self.pipeline_stage_manager: + use_zbv = self.pipeline_stage_manager.use_zbv + else: + use_zbv = False + sp_mode = self.shard_config.sequence_parallelism_mode or None sp_size = self.shard_config.sequence_parallel_size or None sp_group = self.shard_config.sequence_parallel_process_group or None @@ -129,7 +134,7 @@ class LlamaPolicy(Policy): kwargs=dict( seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication, - use_zbv=self.shard_config.use_zbv, + use_zbv=use_zbv, ), ), SubModuleReplacementDescription( @@ -138,7 +143,7 @@ class LlamaPolicy(Policy): kwargs=dict( seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication, - use_zbv=self.shard_config.use_zbv, + use_zbv=use_zbv, ), ), SubModuleReplacementDescription( @@ -147,7 +152,7 @@ class LlamaPolicy(Policy): kwargs=dict( seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication, - use_zbv=self.shard_config.use_zbv, + use_zbv=use_zbv, ), ), SubModuleReplacementDescription( @@ -156,7 +161,7 @@ class LlamaPolicy(Policy): kwargs=dict( seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication, - use_zbv=self.shard_config.use_zbv, + use_zbv=use_zbv, ), ), SubModuleReplacementDescription( @@ -165,7 +170,7 @@ class LlamaPolicy(Policy): kwargs=dict( 
seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication, - use_zbv=self.shard_config.use_zbv, + use_zbv=use_zbv, ), ), SubModuleReplacementDescription( @@ -174,7 +179,7 @@ class LlamaPolicy(Policy): kwargs=dict( seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication, - use_zbv=self.shard_config.use_zbv, + use_zbv=use_zbv, ), ), SubModuleReplacementDescription( @@ -183,7 +188,7 @@ class LlamaPolicy(Policy): kwargs=dict( seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication, - use_zbv=self.shard_config.use_zbv, + use_zbv=use_zbv, ), ), ], @@ -413,6 +418,10 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): from transformers import LlamaForSequenceClassification policy = super().module_policy() + if self.pipeline_stage_manager: + use_zbv = self.pipeline_stage_manager.use_zbv + else: + use_zbv = False if self.shard_config.enable_tensor_parallelism: # add a new item for sequence classification @@ -425,6 +434,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): kwargs=dict( gather_output=True, fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, ), ) ] diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index de546b3c5..11291169a 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -52,6 +52,10 @@ class MixtralPolicy(Policy): sp_group = self.shard_config.sequence_parallel_process_group or None sp_partial_derived = sp_mode in ["split_gather", "ring"] tp_size = self.shard_config.tensor_parallel_size + if self.pipeline_stage_manager: + use_zbv = self.pipeline_stage_manager.use_zbv + else: + use_zbv = False # modified for both SP and TP num_q_heads = self.model.config.num_attention_heads @@ -126,7 +130,7 @@ class MixtralPolicy(Policy): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, - "use_zbv": self.shard_config.use_zbv, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -134,7 +138,7 @@ class MixtralPolicy(Policy): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, - "use_zbv": self.shard_config.use_zbv, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -142,7 +146,7 @@ class MixtralPolicy(Policy): target_module=Linear1D_Col, kwargs={ "fp8_communication": self.shard_config.fp8_communication, - "use_zbv": self.shard_config.use_zbv, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -150,7 +154,7 @@ class MixtralPolicy(Policy): target_module=Linear1D_Row, kwargs={ "fp8_communication": self.shard_config.fp8_communication, - "use_zbv": self.shard_config.use_zbv, + "use_zbv": use_zbv, }, ), SubModuleReplacementDescription( @@ -159,7 +163,7 @@ class MixtralPolicy(Policy): kwargs={ "gather_output": True, "fp8_communication": self.shard_config.fp8_communication, - "use_zbv": self.shard_config.use_zbv, + "use_zbv": use_zbv, }, ), ], @@ -195,7 +199,7 @@ class MixtralPolicy(Policy): "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, "fp8_communication": self.shard_config.fp8_communication, - "use_zbv": self.shard_config.use_zbv, + "use_zbv": use_zbv, }, ) ], @@ -330,6 +334,10 @@ class MixtralModelPolicy(MixtralPolicy): class MixtralForCausalLMPolicy(MixtralPolicy): def module_policy(self): policy = super().module_policy() + if self.pipeline_stage_manager: + use_zbv = self.pipeline_stage_manager.use_zbv + 
else: + use_zbv = False # TODO: assign pg mesh from plugin to all modules if self.shard_config.enable_tensor_parallelism: # add a new item for causal lm @@ -342,7 +350,7 @@ class MixtralForCausalLMPolicy(MixtralPolicy): kwargs=dict( gather_output=True, fp8_communication=self.shard_config.fp8_communication, - use_zbv=self.shard_config.use_zbv, + use_zbv=use_zbv, ), ) ], @@ -392,6 +400,10 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy): from transformers import MixtralForSequenceClassification policy = super().module_policy() + if self.pipeline_stage_manager: + use_zbv = self.pipeline_stage_manager.use_zbv + else: + use_zbv = False if self.shard_config.enable_tensor_parallelism: # add a new item for sequence classification @@ -404,7 +416,7 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy): kwargs=dict( gather_output=True, fp8_communication=self.shard_config.fp8_communication, - use_zbv=self.shard_config.use_zbv, + use_zbv=use_zbv, ), ) ] diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 33e93fa51..1219119bb 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -49,7 +49,6 @@ class ShardConfig: make_vocab_size_divisible_by: int = 64 gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) - use_zbv: bool = False # For ring attention inner_ring_size: Optional[int] = None diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 4f2c45d75..041c51fb1 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -5,8 +5,6 @@ import warnings from contextlib import nullcontext import torch - -torch.autograd.set_detect_anomaly(True) import torch.distributed as dist from data_utils import RandomDataset from model_utils import format_numel_str, get_model_numel diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 1a1fbbeb2..bdc539043 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -8,12 +8,14 @@ import torch import torch.distributed as dist import torch.nn as nn from torch.testing import assert_close +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaModel from transformers.models.mixtral.configuration_mixtral import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralModel import colossalai from colossalai.booster.booster import Booster -from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin, MoeHybridParallelPlugin from colossalai.cluster import ProcessGroupMesh from colossalai.interface import OptimizerWrapper from colossalai.logging import disable_existing_loggers @@ -918,11 +920,181 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): torch.cuda.empty_cache() +@parameterize( + "config", + [ + (0, 4, 1, 1), + # (1, 2, 2, 1), + # (1, 2, 1, 2), + # (1, 1, 2, 2), + ], +) +def run_with_booster_hybridplugin(config: Tuple[int, ...]): + stage, pp_size, tp_size, sp_size = config + num_microbatches = pp_size + dist.get_world_size() + rank = dist.get_rank() + dtype, precision = torch.float16, "fp16" + 
torch.cuda.set_device(dist.get_rank()) + + ######## + # init base model + ######## + assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS" + config = LlamaConfig( + hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS, + intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2, + num_hidden_layers=NUM_LAYERS, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_HEADS, + attn_implementation="flash_attention_2", + ) + + # init model with the same seed + seed_all(10086) + + torch_model = LlamaModel(config).to(dtype).cuda() + # TODO: Support MixtralForCausalLM + # torch_model = MixtralForCausalLM(config).to(dtype).cuda() + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) + # init schedule + h, a, s = config.hidden_size, config.num_attention_heads, 1024 + mem_f = 34 * h + 5 * a * s + mem_w = -32 * h + mem_b = -mem_w - mem_f + graph = PipelineGraph( + n_stage=pp_size, + n_micro=num_microbatches, + f_cost=1, + b_cost=1, + w_cost=1, + c_cost=1, + f_mem=mem_f, + b_mem=mem_b, + w_mem=mem_w, + ) + + zbv_schedule = graph.get_v_schedule() + + # init MoeHybridPlugin + plugin = HybridParallelPlugin( + pp_size=pp_size, + num_microbatches=pp_size, + tp_size=tp_size, + sp_size=sp_size, + zero_stage=stage, + enable_sequence_parallelism=sp_size > 1, + sequence_parallelism_mode="all_to_all" if sp_size > 1 else None, + overlap_communication=False, + initial_scale=1, + precision=precision, + find_unused_parameters=True, + pp_style="zbv", + scheduler_nodes=zbv_schedule, + num_model_chunks=2, + ) + + dp_size = plugin.dp_size + + booster = Booster(plugin=plugin) + + ######## + # init pp model + ######## + + parallel_model = deepcopy(torch_model) + parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1) + parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer) + # create different input along dp axis + seed_all(1453 + rank) + + torch_model.train() + parallel_model.train() + for i in range(2): + # gen random input + # input = torch.rand( + # NUM_BATCH, NUM_TOK_PER_BATCH, NUM_HEADS, HIDDEN_SIZE_PER_HEAD, requires_grad=True + # ).cuda() + input_ids = torch.randint(0, torch_model.vocab_size, (NUM_BATCH, config.max_position_embeddings)).cuda() + attention_mask = torch.ones_like(input_ids).cuda() + input_ids.clone().cuda() + input_data = {"input_ids": input_ids, "attention_mask": attention_mask} + + # dist.all_reduce( + # input, group=plugin.pp_group + # ) # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check + + # dist.all_reduce(input, group=plugin.tp_group) # tp group duplicate input + # dist.all_reduce(input, group=plugin.sp_group) # sp group duplicate input + + # run the model with hybrid parallel + if booster.plugin.stage_manager is not None: + # for test with pp + data_iter = iter([input_data]) + sharded_output = booster.execute_pipeline( + data_iter, + parallel_model, + lambda x, y: x.last_hidden_state.mean(), + parallel_optimizer, + return_loss=True, + return_outputs=True, + ) + # stage 0 chunk 0 + parallel_output = None + if ( + booster.plugin.stage_manager.is_first_stage(ignore_chunk=True) + and rank == dist.get_process_group_ranks(plugin.pp_group)[0] + ): + parallel_output = sharded_output["loss"] + else: + parallel_output = torch.tensor(12345.0, device="cuda") + # broadcast along pp axis + dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group) + + else: + # for test without pp + parallel_output = parallel_model( + 
input_ids=input_data["input_ids"], + attention_mask=input_data["attention_mask"], + ).last_hidden_state.mean() + parallel_optimizer.backward(parallel_output) + parallel_optimizer.step() + parallel_optimizer.zero_grad() + dist.all_reduce(parallel_output, group=plugin.dp_group) + + # =================================================================================== + # run normal model with all dp(different) inputs + all_inputs = [input_data for _ in range(dp_size)] + # dist.all_gather(all_inputs, input, group=plugin.dp_group) + torch_output_sum = 0 + for input_data_ in all_inputs: + torch_output = torch_model( + input_ids=input_data_["input_ids"], + attention_mask=input_data_["attention_mask"], + ).last_hidden_state.mean() + torch_output.backward() + torch_output_sum += torch_output.detach() + # print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}") + # avg dp grads follows zero optimizer + for p in torch_model.parameters(): + if p.grad is not None: + p.grad /= dp_size + torch_optimizer.step() + torch_optimizer.zero_grad() + + print(f"loop {i} rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}") + # assert_loose_close(parallel_output, torch_output_sum, dtype=dtype) + # print(f"rank {dist.get_rank()} config {test_config} test passed") + clear_layout_converter() + Randomizer.reset_index() + torch.cuda.empty_cache() + + def run_dist(rank, world_size, port): disable_existing_loggers() colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - # run_fwd_bwd_vschedule_with_optim() run_with_booster_moehybridplugin() + # run_with_booster_hybridplugin() @pytest.mark.dist diff --git a/tests/test_pipeline/test_schedule/zbv_poc.py b/tests/test_pipeline/test_schedule/zbv_poc.py deleted file mode 100644 index 6280990a9..000000000 --- a/tests/test_pipeline/test_schedule/zbv_poc.py +++ /dev/null @@ -1,628 +0,0 @@ -import gc -import time -from copy import deepcopy - -import torch -import torch.nn as nn -from torch.testing import assert_close - - -def get_model_numel(model): - return sum(p.numel() for p in model.parameters()) / 1024**2 - - -# Step1: dx = w*dy -def backward_b(loss, x, model): - torch.autograd.backward(loss, inputs=x, retain_graph=True) - - -# Step2: dummy dw = x*dy -def backward_w(loss, model): - torch.autograd.backward(loss, inputs=list(model.parameters())) - - -def test_double_dx_dw_split_nsync(): - device = "cuda:0" - model = nn.Linear(4096, 4096, bias=None).to(device=device) - # print(f"model numel {get_model_numel(model)}") # 4GB - x1 = torch.rand(4096, 4096).to(device=device) - x2 = torch.rand(4096, 4096).to(device=device) - ref_model = deepcopy(model) - ref_x1 = x1.clone() - ref_x2 = x1.clone() - - # first step - x1.requires_grad_() - x2.requires_grad_() - ref_x1.requires_grad_() - ref_x2.requires_grad_() - - # loss for dx_dw bwd - loss1 = model(x1).sum() - loss2 = model(x2).sum() - - # loss for common bwd - ref_loss1 = ref_model(ref_x1).sum() - ref_loss2 = ref_model(ref_x2).sum() - - # dx1 - torch.cuda.synchronize() - bwd_b_start_time = time.time() - backward_b(loss1, x1, model) - bwd_b_end_time = time.time() - print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}") - - for p in model.parameters(): - assert p.grad is None - assert x1.grad is not None - - # dx2 - torch.cuda.synchronize() - bwd_b_start_time = time.time() - backward_b(loss2, x2, model) - bwd_b_end_time = time.time() - print(f"loss_2 bwd B runtime {bwd_b_end_time - bwd_b_start_time}") - - # dw1 - 
torch.cuda.synchronize() - bwd_w_start_time = time.time() - backward_w(loss1, model) - bwd_w_end_time = time.time() - print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}") - for p in model.parameters(): - assert p.grad is not None - - # common bwd 1 - torch.cuda.synchronize() - comm_bwd_start_time = time.time() - ref_loss1.backward() - comm_bwd_end_time = time.time() - print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}") - - # # assert dx1 & dw1 == bwd 1 - # assert_close(x1.grad, ref_x1.grad) - # for p1, p2 in zip(model.parameters(), ref_model.parameters()): - # assert_close(p1, p2) - # assert_close(p1.grad, p2.grad) - - # dw2 - torch.cuda.synchronize() - bwd_w_start_time = time.time() - backward_w(loss2, model) - bwd_w_end_time = time.time() - print(f"loss_2 bwd W runtime {bwd_w_end_time - bwd_w_start_time}") - - # common bwd 2 - torch.cuda.synchronize() - comm_bwd_start_time = time.time() - ref_loss2.backward() - comm_bwd_end_time = time.time() - print(f"loss_2 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}") - - # # assert dx2 & dw2 == bwd 2 - # assert_close(x2.grad, ref_x2.grad) - # for p1, p2 in zip(model.parameters(), ref_model.parameters()): - # print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n") - # assert_close(p1, p2) - # assert_close(p1.grad, p2.grad) - - -def test_double_dx_dw_split_sync(): - device = "cuda:0" - model = nn.Linear(8, 8, bias=None).to(device=device) - print(f"model size {get_model_numel(model)} ") # 4GB - x1 = torch.rand(8, 8).to(device=device) - x2 = torch.rand(8, 8).to(device=device) - - # x1 = torch.ones(8, 8).to(device=device) - # x2 = torch.ones(8, 8).to(device=device) - - ref_model = deepcopy(model) - ref_x1 = x1.clone() - ref_x2 = x2.clone() - - x1.requires_grad_() - x2.requires_grad_() - ref_x1.requires_grad_() - ref_x2.requires_grad_() - - ############ - # step1: - ############ - - # loss1 - loss1 = model(x1).sum() - - # ref_loss1 - ref_model(ref_x1).sum() - - # dx1 - backward_b(loss1, x1, model) - for p in model.parameters(): - assert p.grad is None - assert x1.grad is not None - - # dw1 - backward_w(loss1, model) - for p in model.parameters(): - assert p.grad is not None - - # common bwd 1 - # ref_loss1.backward() - - # assert dx1 & dw1 == bwd 1 - assert_close(x1.grad, ref_x1.grad) - for p1, p2 in zip(model.parameters(), ref_model.parameters()): - assert_close(p1, p2) - assert_close(p1.grad, p2.grad) - - ############ - # step2: - ############ - - # loss2 - loss2 = model(x2).sum() - - # ref_loss2 - ref_loss2 = ref_model(ref_x2).sum() - - for p1, p2 in zip(model.parameters(), ref_model.parameters()): - print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n") - assert_close(p1, p2) - assert_close(p1.grad, p2.grad) - - # dx2 - backward_b(loss2, x2, model) - - # dw2 - backward_w(loss2, model) - - # common bwd 2 - ref_loss2.backward() - - # assert dx2 & dw2 == bwd 2 - assert_close(x2.grad, ref_x2.grad) - for p1, p2 in zip(model.parameters(), ref_model.parameters()): - print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n") - assert_close(p1, p2) - assert_close(p1.grad, p2.grad) - - -def deallocate_output_tensor(out): - """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field. - - This method should be called right after the output tensor has been - sent to the next pipeline stage. At this point, the output tensor is - only useful for its '.grad_fn' field, and not its '.data'. - """ - assert isinstance(out, torch.Tensor), "expected Tensor, found %s." 
% type(out).__name__ - assert out._base is None, "counter-productive to free a view of another tensor." - out.data = torch.empty( - (1,), - device=out.device, - dtype=out.dtype, - ) - - -IN_DIM = 8192 -OUT_DIM = 8192 -NUM_LAYER = 3 - - -class MlpModel(nn.Module): - def __init__(self): - super().__init__() - self.layers = nn.ModuleList([nn.Linear(IN_DIM, OUT_DIM, bias=None) for _ in range(NUM_LAYER)]) - - def forward(self, x): - for layer in self.layers: - x = layer(x) - return x - - -class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, with_qkv=True): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim**-0.5 - self.with_qkv = with_qkv - if self.with_qkv: - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - self.attn_drop = nn.Dropout(attn_drop) - - def forward(self, x): - B, N, C = x.shape - if self.with_qkv: - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - else: - qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - q, k, v = qkv, qkv, qkv - - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - if self.with_qkv: - x = self.proj(x) - x = self.proj_drop(x) - return x - - -def mem_dx_dw(): - device = "cuda:0" - # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device) - print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - model = MlpModel().to(device=device) - print(f"model numel {get_model_numel(model)}") # 4GB - print(f"After init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - - print(f"Before init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device) - x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device) - x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device) - - x1.requires_grad_() - x2.requires_grad_() - x3.requires_grad_() - print(f"After init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - - ############ - # step1: - ############ - print(f"\nStep1") - - # loss1 - print(f"Before Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - y1 = model(x1) - print(f"After Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - - print(f"Before loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - loss1 = y1.sum() - print(f"After loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - - # dx1 - backward_b(loss1, x1, model) - - # dw1 - backward_w(loss1, model) - - deallocate_output_tensor(x1) - deallocate_output_tensor(y1) - # del x1 - # del y1 - print(f"After del x1&y1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - - # print(f"\n Step1:collect:{gc.collect()}") - # print(f"object: {gc.get_objects()}") - # print(f"garbage: {gc.garbage}") - - ############ - # step2: - ############ - print(f"\nStep2") - - # loss2 - y2 = model(x2) - loss2 = y2.sum() - - # dx2 - backward_b(loss2, x2, model) - - # dw2 - backward_w(loss2, model) - deallocate_output_tensor(x2) - deallocate_output_tensor(y2) - # del x2 - # del y2 - print(f"After del x2&y2: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - - print(f"\n Step2:collect:{gc.collect()}") - # print(f"object: {gc.get_objects()}") - print(f"garbage: {gc.garbage}") - - ############ - # step3: - 
############ - - print(f"\nStep3") - - # loss3 - y3 = model(x3) - loss3 = y3.sum() - - # dx2 - backward_b(loss3, x3, model) - - # dw2 - backward_w(loss3, model) - - deallocate_output_tensor(x3) - deallocate_output_tensor(y3) - # del x3 - # del y3 - - print(f"After del x3&y3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - - print(f"\n Step3:collect:{gc.collect()}") - # print(f"object: {gc.get_objects()}") - print(f"garbage: {gc.garbage}") - - -# del activation -def activation_dx_dw(): - device = "cuda:0" - # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device) - print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - model = MlpModel().to(device=device) - x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device) - x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device) - x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device) - - x1.requires_grad_() - x2.requires_grad_() - x3.requires_grad_() - print(f"After init Model, x1,x2,x3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - - activations = {} - - def register_hooks(module): - def activation_hook(module, input, output): - activations[f"{module.__class__.__name__}_{id(module)}"] = output.detach() - - def bwd_hook(module, grad_input, grad_output): - del activations[f"{module.__class__.__name__}_{id(module)}"] - - module.register_forward_hook(activation_hook) - module.register_backward_hook(bwd_hook) - - model.apply(register_hooks) - - ############ - # step1: - ############ - print(f"\nStep1") - - # loss1 - loss1 = model(x1).sum() - - # dx1 - backward_b(loss1, x1, model) - - # dw1 - backward_w(loss1, model) - - del loss1, x1 - print(f"After del x1&y1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - - ############ - # step2: - ############ - print(f"\nStep2") - - # loss2 - loss2 = model(x2).sum() - - # dx2 - backward_b(loss2, x2, model) - - # dw2 - backward_w(loss2, model) - - # deallocate_output_tensor(x2) - # deallocate_output_tensor(loss2) - del x2, loss2 - print(f"After del x2&y2: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - - ############ - # step3: - ############ - print(f"\nStep3") - - # loss3 - loss3 = model(x3).sum() - - # dx2 - backward_b(loss3, x3, model) - - # dw2 - backward_w(loss3, model) - - del x3, loss3 - - print(f"After del x3&y3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - - -# text dx dw in model chunk -def model_chunk_dx_dw(): - device = "cuda:0" - num_layers = 4 - print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - model = MlpModel(in_dim=4096, out_dim=4096, num_layers=num_layers).to(device=device) - x = torch.rand(4096, 4096).to(device=device) - x.requires_grad_() - - model_chunk_0 = torch.nn.ModuleList() # for layer 1 & 2 - model_chunk_1 = torch.nn.ModuleList() # for layer 3 & 4 - - for idx, sub_model in enumerate(model.layers): - if idx < 2: - model_chunk_0.append(sub_model).cuda() - else: - model_chunk_1.append(sub_model).cuda() - - print(f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - - # Step1:chunk 0 fwd - activation = dict() # layer_id: activation - out = x - for i in range(len(model_chunk_0)): - layer = model_chunk_0[i] - activation[i] = layer(out) - print(f"After chunk0 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - # Step2:chunk 1 fwd - for i in range(len(model_chunk_1)): - layer = model_chunk_0[i] - activation[i + 2] = layer(out) - print(f"After chunk1 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") - - # Step3:chunk 1 bwd b: dx=w*dy & bwd w:dw=x*dy - # visit layer reversely - for 
i in range(len(model_chunk_1) - 1, -1, -1): - layer = model_chunk_1[i] - global_layer_idx = i + 2 - prev_global_layer_idx = i + 1 if i + 1 > 0 else None - i + 3 if i + 3 < 4 else None - - # bwd b - if global_layer_idx == num_layers - 1: # last layer in last chunk; calculate loss - loss = activation[global_layer_idx].sum() - x = activation[prev_global_layer_idx] - backward_b(loss, x, layer) - else: - loss = activation[global_layer_idx].sum() - x = activation[prev_global_layer_idx] - backward_b(loss, x, layer) - - # bwd w - backward_w(loss, layer) - - -def test_dx_dw_linear_benchmark(): - device = "cuda:0" - model = nn.Linear(4096, 4096, bias=None).to(device=device) - # print(f"model numel {get_model_numel(model)}") # 4GB - x1 = torch.rand(4096, 4096).to(device=device) - # x2 = torch.rand(4096, 4096).to(device=device) - ref_model = deepcopy(model) - ref_x1 = x1.clone() - # ref_x2 = x1.clone() - - # first step - x1.requires_grad_() - # x2.requires_grad_() - ref_x1.requires_grad_() - # ref_x2.requires_grad_() - - # loss for dx_dw bwd - loss1 = model(x1).sum() - # loss2 = model(x2).sum() - - # loss for common bwd - ref_model(ref_x1).sum() - # ref_loss2 = ref_model(ref_x2).sum() - - with torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], - # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1), - on_trace_ready=torch.profiler.tensorboard_trace_handler( - f"/home/nvme-share/home/duanjunwen/ColossalAI/tests/test_pipeline/test_schedule" - ), - record_shapes=True, - profile_memory=True, - with_stack=True, - with_flops=True, - ) as prof: - # dx1 - torch.cuda.synchronize() - bwd_b_start_time = time.time() - backward_b(loss1, x1, model) - bwd_b_end_time = time.time() - print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}") - - for p in model.parameters(): - assert p.grad is None - assert x1.grad is not None - - # dw1 - torch.cuda.synchronize() - bwd_w_start_time = time.time() - backward_w(loss1, model) - bwd_w_end_time = time.time() - print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}") - for p in model.parameters(): - assert p.grad is not None - - # # common bwd 1 - # torch.cuda.synchronize() - # comm_bwd_start_time = time.time() - # ref_loss1.backward() - # comm_bwd_end_time = time.time() - # print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}") - - -def test_dx_dw_attn_benchmark(): - device = "cuda:0" - model = Attention(dim=4096).to(device=device) - # print(f"model numel {get_model_numel(model)}") # 4GB - x1 = torch.rand(1, 256, 4096).to(device=device) - # x2 = torch.rand(1, 256, 4096).to(device=device) - ref_model = deepcopy(model) - ref_x1 = x1.clone() - # ref_x2 = x1.clone() - - # first step - x1.requires_grad_() - # x2.requires_grad_() - ref_x1.requires_grad_() - # ref_x2.requires_grad_() - - # loss for dx_dw bwd - loss1 = model(x1).sum() - # loss2 = model(x2).sum() - - # loss for common bwd - ref_model(ref_x1).sum() - # ref_loss2 = ref_model(ref_x2).sum() - - with torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], - # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1), - on_trace_ready=torch.profiler.tensorboard_trace_handler( - f"/home/nvme-share/home/duanjunwen/ColossalAI/tests/test_pipeline/test_schedule" - ), - record_shapes=True, - profile_memory=True, - with_stack=True, - with_flops=True, - ) as prof: - # dx1 - torch.cuda.synchronize() - bwd_b_start_time = time.time() - 
backward_b(loss1, x1, model) - bwd_b_end_time = time.time() - print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}") - - for p in model.parameters(): - assert p.grad is None - assert x1.grad is not None - - # dw1 - torch.cuda.synchronize() - bwd_w_start_time = time.time() - backward_w(loss1, model) - bwd_w_end_time = time.time() - print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}") - for p in model.parameters(): - assert p.grad is not None - - # # common bwd 1 - # torch.cuda.synchronize() - # comm_bwd_start_time = time.time() - # ref_loss1.backward() - # comm_bwd_end_time = time.time() - # print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}") - - -if __name__ == "__main__": - # test_dx_dw_split() - # test_double_dx_dw_split_nsync() - # test_double_dx_dw_split_sync() - # mem_dx_dw() - # activation_dx_dw() - # test_dx_dw_linear_benchmark() - test_dx_dw_attn_benchmark() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index ce513f1fd..33707a4f6 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -277,7 +277,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "precision": "fp16", "initial_scale": 1, }, - # TODO: assert layer error + # # TODO: assert layer error # { # "tp_size": 2, # "pp_size": 2,
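Note on the policy changes above: the same pattern is applied in every touched `module_policy()` for Llama and Mixtral. The `use_zbv` field is deleted from `ShardConfig` and from the two plugin constructors, and each policy instead derives the flag from its pipeline stage manager. A minimal sketch of that guard, condensed from the hunks above (illustrative excerpt, not an additional hunk of this patch):

    # Derive use_zbv from the pipeline stage manager; fall back to False
    # when the policy runs without pipeline parallelism.
    if self.pipeline_stage_manager:
        use_zbv = self.pipeline_stage_manager.use_zbv
    else:
        use_zbv = False

    # The flag is then forwarded to the replaced linear layers in place of
    # the removed self.shard_config.use_zbv, e.g.:
    #   SubModuleReplacementDescription(
    #       suffix="self_attn.q_proj",
    #       target_module=Linear1D_Col,
    #       kwargs=dict(
    #           seq_parallel_mode=sp_mode,
    #           fp8_communication=self.shard_config.fp8_communication,
    #           use_zbv=use_zbv,
    #       ),
    #   )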
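Note on the new test: `run_with_booster_hybridplugin` added in test_zerobubble_pp.py mirrors the existing MoE test but drives a dense `LlamaModel` through the zero-bubble-V path; it builds a V schedule with `PipelineGraph` and hands it to `HybridParallelPlugin` via `pp_style="zbv"` and `scheduler_nodes` (the call in `run_dist` is still commented out). A condensed sketch of that setup follows; the import paths and the placeholder sizes are assumptions for illustration, the keyword arguments mirror the test code above:

    # Assumed import paths; the test file imports these elsewhere.
    from colossalai.booster.plugin import HybridParallelPlugin
    from colossalai.pipeline.schedule.v_schedule import PipelineGraph

    pp_size, tp_size, sp_size, zero_stage = 4, 1, 1, 0   # mirrors the (0, 4, 1, 1) config
    h, a, s = 512, 4, 1024                               # hidden size, attention heads, seq len (placeholders)

    # Memory coefficients the test uses to weight F/B/W nodes of the schedule.
    mem_f = 34 * h + 5 * a * s
    mem_w = -32 * h
    mem_b = -mem_w - mem_f

    graph = PipelineGraph(
        n_stage=pp_size,
        n_micro=pp_size,
        f_cost=1, b_cost=1, w_cost=1, c_cost=1,
        f_mem=mem_f, b_mem=mem_b, w_mem=mem_w,
    )
    zbv_schedule = graph.get_v_schedule()

    plugin = HybridParallelPlugin(
        pp_size=pp_size,
        num_microbatches=pp_size,
        tp_size=tp_size,
        sp_size=sp_size,
        zero_stage=zero_stage,
        precision="fp16",
        initial_scale=1,
        pp_style="zbv",               # zero-bubble-V pipeline schedule
        scheduler_nodes=zbv_schedule,
        num_model_chunks=2,           # the V schedule places two chunks per stage
    )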
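Note on the deleted file: `zbv_poc.py` was a standalone proof of concept (timing and memory probes, debug prints) and is removed by this patch. For context, the core mechanism it exercised is the split backward that zero-bubble scheduling relies on: a B step that produces only input gradients and a W step that produces only weight gradients. A self-contained sketch of that split, taken from the helpers in the deleted file with a small usage example added for illustration:

    import torch
    import torch.nn as nn

    def backward_b(loss, x, model):
        # B step: compute only dx; keep the graph alive for the later W step.
        torch.autograd.backward(loss, inputs=x, retain_graph=True)

    def backward_w(loss, model):
        # W step: compute only dw, releasing the graph.
        torch.autograd.backward(loss, inputs=list(model.parameters()))

    model = nn.Linear(8, 8, bias=None)
    x = torch.rand(8, 8, requires_grad=True)
    loss = model(x).sum()

    backward_b(loss, x, model)   # x.grad is populated, parameter grads are still None
    backward_w(loss, model)      # parameter grads are populated now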