diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 2cad504f3..b01d15490 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -201,4 +201,4 @@ jobs: uses: actions/upload-artifact@v3 with: name: report - path: report/ + path: report/ \ No newline at end of file diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index ae1a5275e..510665b46 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -83,4 +83,4 @@ jobs: SERVER_URL: ${{github.server_url }} REPO: ${{ github.repository }} RUN_ID: ${{ github.run_id }} - WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }} + WEBHOOK_URL: ${{ secrets.LARK_NOTIFICATION_WEBHOOK_URL }} \ No newline at end of file diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index da67e6b41..bf677e052 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -36,6 +36,8 @@ from .pp_plugin_base import PipelinePluginBase DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 +PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} + def _convert_floating_point(x, dtype: torch.dtype = torch.float16): if isinstance(x, torch.Tensor) and torch.is_floating_point(x): @@ -1059,6 +1061,7 @@ class HybridParallelPlugin(PipelinePluginBase): overlap_communication=overlap_communication, cpu_offload=cpu_offload, partition_grad=(self.zero_stage == 2), + forced_dtype=PRECISION_TORCH_TYPE[precision], ) self.max_norm = max_norm diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 4bca335c8..d4960c7e4 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -9,6 +9,7 @@ except: try: import fused_weight_gradient_mlp_cuda + _grad_accum_fusion_available = True except ImportError: _grad_accum_fusion_available = False @@ -78,7 +79,8 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. 
         weight = weight.view(weight.shape)
-        bias = bias.view(bias.shape)
+        if bias is not None:
+            bias = bias.view(bias.shape)
 
         total_input = input
         grad_input = grad_output.matmul(weight.T)
 
@@ -91,9 +93,8 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
         if ctx.async_grad_allreduce:
             # Asynchronous all-reduce
             handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
-            # Delay the start of weight gradient computation shortly (3us) to have
-            # all-reduce scheduled first and have GPU resources allocated
-            _ = torch.empty(1, device=grad_output.device) + 1
+            # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have the
+            # all-reduce scheduled first and have GPU resources allocated; CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
 
         grad_weight = total_input.t().matmul(grad_output)
         grad_bias = grad_output.sum(dim=0) if use_bias else None
 
@@ -115,7 +116,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_allreduce = async_grad_allreduce
-
         if bias is not None:
             output = F.linear(input_, weight, bias)
         else:
@@ -143,9 +143,8 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         if ctx.async_grad_allreduce:
             # Asynchronous all-reduce
             handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
-            # Delay the start of weight gradient computation shortly (3us) to have
-            # all-reduce scheduled first and have GPU resources allocated
-            _ = torch.empty(1, device=grad_output.device) + 1
+            # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have the
+            # all-reduce scheduled first and have GPU resources allocated; CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
 
         if _grad_accum_fusion_available and weight.grad is not None:
             grad = weight.grad
@@ -228,9 +227,8 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
                 input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
             ).contiguous()
             handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
-            # Delay the start of weight gradient computation shortly (3us) to have
-            # reduce-scatter scheduled first and have GPU resources allocated
-            _ = torch.empty(1, device=grad_output.device) + 1
+            # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have the
+            # reduce-scatter scheduled first and have GPU resources allocated; CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
 
             if _grad_accum_fusion_available and weight.grad is not None:
                 grad = weight.grad
@@ -394,9 +392,8 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
                 input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
             ).contiguous()
             handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
-            # Delay the start of weight gradient computation shortly (3us) to have
-            # reduce-scatter scheduled first and have GPU resources allocated
-            _ = torch.empty(1, device=grad_output.device) + 1
+            # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have the
+            # reduce-scatter scheduled first and have GPU resources allocated; CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
 
             grad_weight = total_input.t().matmul(grad_output)
             grad_bias = grad_output.sum(dim=0) if use_bias else None
@@ -431,7 +428,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
             # calculate gradient
             if len(input_parallel.shape) > 2:
-                input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
+                input_parallel = 
input_parallel.view(-1, input_parallel.shape[-1]) grad_weight = input_parallel.t().matmul(grad_output) # wait until reduce-scatter finished reducescatter_handle.wait() diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 055e3096d..3e5cc6015 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -24,6 +24,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig +from ..layer import cross_entropy_1d + class GPT2PipelineForwards: """ @@ -326,7 +328,15 @@ class GPT2PipelineForwards: shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + shift_logits = shift_logits.view(-1, shift_logits.size(-1)) + shift_labels = shift_labels.view(-1) + if shard_config.enable_tensor_parallelism: + loss = cross_entropy_1d( + shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + ) + else: + loss = loss_fct(shift_logits, shift_labels) + if not return_dict: output = (lm_logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output @@ -1006,3 +1016,84 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): ) return forward + + +def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): + from transformers import GPT2LMHeadModel + + def forward( + self: GPT2LMHeadModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, shift_logits.size(-1)) + shift_labels = shift_labels.view(-1) + if shard_config.enable_tensor_parallelism: + loss = cross_entropy_1d( + shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group + ) + else: + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 022e6ff5b..303766993 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -5,7 +5,12 @@ from torch import Tensor, nn import colossalai.shardformer.layer as col_nn -from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn +from ..modeling.gpt2 import ( + GPT2PipelineForwards, + get_gpt2_flash_attention_forward, + get_lm_forward_with_dist_cross_entropy, + gpt2_sequence_parallel_forward_fn, +) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -87,9 +92,7 @@ class GPT2Policy(Policy): SubModuleReplacementDescription( suffix="mlp.c_proj", target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={ - "seq_parallel": use_sequence_parallel, - }, + kwargs={"seq_parallel": use_sequence_parallel}, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", @@ -167,15 +170,35 @@ class GPT2Policy(Policy): stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.wte) - held_layers.append(module.wpe) - held_layers.append(module.drop) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.h[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.ln_f) + if 
stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = self.distribute_layers( + len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks + ) + stage_indices = Policy.get_stage_index( + layers_per_stage, + stage_manager.stage, + num_model_chunks=stage_manager.num_model_chunks, + num_stages=stage_manager.num_stages, + ) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.wte) + held_layers.append(module.wpe) + held_layers.append(module.drop) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(module.ln_f) + else: + layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.wte) + held_layers.append(module.wpe) + held_layers.append(module.drop) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.h[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.ln_f) return held_layers def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: @@ -189,13 +212,27 @@ class GPT2Policy(Policy): else: module = self.model.transformer - layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = { - "forward": partial( - new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + if stage_manager.is_interleave: + layers_per_stage = self.distribute_layers( + len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks + ) + stage_manager.stage_indices = Policy.get_stage_index( + layers_per_stage, + stage_manager.stage, + num_model_chunks=stage_manager.num_model_chunks, + num_stages=stage_manager.num_stages, ) - } + method_replacement = { + "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) + } + else: + layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) @@ -232,9 +269,10 @@ class GPT2LMHeadModelPolicy(GPT2Policy): GPT2LMHeadModel: ModulePolicyDescription( sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": False} ) - ] + ], + method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, ) } module_policy.update(addon_module) @@ -249,7 +287,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy): def get_held_layers(self) -> List[nn.Module]: held_layers = super().get_held_layers() - if self.pipeline_stage_manager.is_last_stage(): + if self.pipeline_stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.lm_head) return held_layers diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py index 7a0d75bf2..b132f47fd 100644 --- 
a/colossalai/shardformer/shard/shardformer.py
+++ b/colossalai/shardformer/shard/shardformer.py
@@ -1,3 +1,4 @@
+import os
 from typing import Dict, List, Tuple
 
 import torch.nn as nn
@@ -9,6 +10,9 @@ from ..policies.base_policy import Policy
 from .shard_config import ShardConfig
 from .sharder import ModelSharder
 
+# Set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when communication and computation overlap, kernels are dispatched to the GPU in launch order.
+os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
+
 
 class ShardFormer:
     """
diff --git a/examples/__init__.py b/examples/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/language/__init__.py b/examples/language/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/language/llama2/data_utils.py b/examples/language/data_utils.py
similarity index 99%
rename from examples/language/llama2/data_utils.py
rename to examples/language/data_utils.py
index 6b9e8ef28..ec849ef9d 100644
--- a/examples/language/llama2/data_utils.py
+++ b/examples/language/data_utils.py
@@ -121,4 +121,4 @@ class RandomDataset(Dataset):
             "input_ids": self.input_ids[idx],
             "attention_mask": self.attention_mask[idx],
             "labels": self.input_ids[idx],
-        }
+        }
\ No newline at end of file
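Context for the CUDA_DEVICE_MAX_CONNECTIONS=1 change in shardformer.py above: limiting the device to a single hardware work queue makes kernels dispatch in launch order, so an asynchronous collective issued before a matmul is actually scheduled first and the two can overlap. A minimal sketch of that launch-order pattern, illustration only and not part of this patch, assuming torch.distributed is already initialized with a CUDA-capable backend:

import os
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"  # must be set before the CUDA context is created

import torch
import torch.distributed as dist

def overlap_allreduce_with_matmul(grad_output, grad_input, total_input, process_group):
    # Issue the communication kernel first; with a single work queue it is
    # dispatched before the matmul below, so the two can run concurrently.
    handle = dist.all_reduce(grad_input, group=process_group, async_op=True)
    grad_weight = total_input.t().matmul(grad_output)  # overlaps with the all-reduce
    handle.wait()
    return grad_input, grad_weight

Because the variable is read when the CUDA context is created, setting it at import time in shardformer.py (rather than inside the backward pass) is what makes this ordering guarantee hold.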
diff --git a/examples/language/gpt/hybridparallelism/benchmark.py b/examples/language/gpt/hybridparallelism/benchmark.py
new file mode 100644
index 000000000..1315deae6
--- /dev/null
+++ b/examples/language/gpt/hybridparallelism/benchmark.py
@@ -0,0 +1,228 @@
+import argparse
+import resource
+from contextlib import nullcontext
+
+import torch
+from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
+from torch.optim import Adam
+from tqdm import tqdm
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+
+import colossalai
+
+# import colossalai.utils.device as device_utils
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
+from colossalai.utils import get_current_device
+from examples.language.data_utils import RandomDataset
+from examples.language.model_utils import format_numel_str, get_model_numel
+from examples.language.performance_evaluator import PerformanceEvaluator
+
+# ==============================
+# Constants
+# ==============================
+MODEL_CONFIGS = {
+    "118M": GPT2Config(activation_function="gelu"),
+    "338M": GPT2Config(n_embd=1024, n_head=16, n_layer=24, activation_function="gelu"),
+    "738M": GPT2Config(n_embd=1280, n_head=20, n_layer=36, activation_function="gelu"),
+    "6.21B": GPT2Config(n_embd=4096, n_head=32, n_layer=32, n_positions=4096, activation_function="gelu"),
+}
+
+
+def main():
+    # ==============================
+    # Parse Arguments
+    # ==============================
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-c", "--config", type=str, default="6.21B", help="Model configuration")
+    parser.add_argument(
+        "-p",
+        "--plugin",
+        choices=["gemini", "gemini_auto", "fsdp", "fsdp_cpu", "3d", "3d_cpu"],
+        default="gemini",
+        help="Choose which plugin to use",
+    )
+    parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size")
+    parser.add_argument("-s", "--num_steps", type=int, default=200, help="Number of steps to run")
+    parser.add_argument("-i", "--ignore_steps", type=int, default=3, help="Number of steps to ignore")
+    parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
+    parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
+    parser.add_argument(
+        "-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto"
+    )
+    parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb")
+    parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini")
+    parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
+    parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
+    parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
+    parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
+    parser.add_argument("--mbs", type=int, default=1)
+    parser.add_argument("--zero", type=int, default=0)
+    parser.add_argument("--pp_style", type=str, default="1f1b")
+    parser.add_argument("--num_model_chunks", type=int, default=2)
+    parser.add_argument("--cpu_offload", action="store_true", help="Use CPU offload")
+    args = parser.parse_args()
+
+    colossalai.launch_from_torch({})
+    coordinator = DistCoordinator()
+
+    def empty_init():
+        pass
+
+    # ==============================
+    # Initialize Booster
+    # ==============================
+    use_empty_init = True
+    if args.plugin == "gemini":
+        plugin = GeminiPlugin(
+            precision="bf16",
+            shard_param_frac=args.shard_param_frac,
+            offload_optim_frac=args.offload_optim_frac,
+            offload_param_frac=args.offload_param_frac,
+            tp_size=args.tp,
+            extra_dp_size=args.extra_dp,
+        )
+    elif args.plugin == "gemini_auto":
+        plugin = GeminiPlugin(
+            placement_policy="auto",
+            precision="bf16",
+            warmup_non_model_data_ratio=args.warmup_ratio,
+            tp_size=args.tp,
+            extra_dp_size=args.extra_dp,
+        )
+    elif args.plugin == "fsdp":
+        if use_empty_init:
+            plugin = TorchFSDPPlugin(
+                mixed_precision=MixedPrecision(
+                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
+                ),
+                param_init_fn=empty_init(),
+            )
+        else:
+            plugin = TorchFSDPPlugin(
+                mixed_precision=MixedPrecision(
+                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
+                )
+            )
+    elif args.plugin == "fsdp_cpu":
+        if use_empty_init:
+            plugin = TorchFSDPPlugin(
+                mixed_precision=MixedPrecision(
+                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
+                ),
+                cpu_offload=CPUOffload(offload_params=True),
+                param_init_fn=empty_init(),
+            )
+        else:
+            plugin = TorchFSDPPlugin(
+                mixed_precision=MixedPrecision(
+                    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
+                ),
+                cpu_offload=CPUOffload(offload_params=True),
+            )
+    elif args.plugin == "3d":
+        plugin = HybridParallelPlugin(
+            tp_size=args.tp,
+            pp_size=args.pp,
+            pp_style=args.pp_style,
+            zero_stage=args.zero,
+            num_model_chunks=args.num_model_chunks,
+            enable_all_optimization=True,
+            num_microbatches=args.mbs,
+            cpu_offload=args.cpu_offload,
+            precision="bf16",
+        )
+    elif args.plugin == "3d_cpu":
+        plugin = HybridParallelPlugin(
+            tp_size=args.tp,
+            pp_size=args.pp,
+            zero_stage=args.zero,
+            cpu_offload=True,
+            enable_fused_normalization=torch.cuda.is_available(),
num_microbatches=args.mbs, + initial_scale=2**8, + precision="bf16", + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + booster = Booster(plugin=plugin) + + # ============================== + # Initialize Dataset and Dataloader + # ============================== + dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size + + config = MODEL_CONFIGS[args.config] + dataset = RandomDataset( + num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size + ) + dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) + + # ============================== + # Initialize Model and Optimizer + # ============================== + init_ctx = ( + LazyInitContext(default_device=get_current_device()) + if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) + else nullcontext() + ) + + with init_ctx: + model = GPT2LMHeadModel(config) + + if args.grad_checkpoint: + model.gradient_checkpointing_enable() + + model_numel = get_model_numel(model) + coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") + performance_evaluator = PerformanceEvaluator( + model_numel, + model.config.n_layer, + model.config.n_embd, + model.config.vocab_size, + args.grad_checkpoint, + args.ignore_steps, + dp_world_size=dp_size, + ) + + optimizer = Adam(model.parameters()) + torch.set_default_dtype(torch.bfloat16) + model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) + torch.set_default_dtype(torch.float) + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" + ) + + if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: + data_iter = iter(dataloader) + for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()): + performance_evaluator.on_step_start(step) + booster.execute_pipeline( + data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, return_loss=False + ) + optimizer.step() + optimizer.zero_grad() + performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) + else: + for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())): + performance_evaluator.on_step_start(step) + outputs = model(**batch) + loss = outputs[0] + booster.backward(loss, optimizer) + optimizer.step() + optimizer.zero_grad() + performance_evaluator.on_step_end(**batch) + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + + performance_evaluator.on_fit_end() + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + main() diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index 54b023f64..832465490 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -19,6 +19,9 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchF from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam +from examples.language.data_utils import RandomDataset +from examples.language.model_utils import format_numel_str, 
get_model_numel +from examples.language.performance_evaluator import PerformanceEvaluator # ============================== # Constants diff --git a/examples/language/llama2/model_utils.py b/examples/language/model_utils.py similarity index 100% rename from examples/language/llama2/model_utils.py rename to examples/language/model_utils.py diff --git a/examples/language/llama2/performance_evaluator.py b/examples/language/performance_evaluator.py similarity index 100% rename from examples/language/llama2/performance_evaluator.py rename to examples/language/performance_evaluator.py diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py index a16b16ad6..fce81ab52 100644 --- a/tests/kit/model_zoo/registry.py +++ b/tests/kit/model_zoo/registry.py @@ -102,4 +102,4 @@ class ModelZooRegistry(dict): return new_dict -model_zoo = ModelZooRegistry() +model_zoo = ModelZooRegistry() \ No newline at end of file diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index d629e769d..285c4866c 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -276,4 +276,4 @@ def test_gemini_plugin(early_stop: bool = True): if __name__ == "__main__": - test_gemini_plugin(early_stop=False) + test_gemini_plugin(early_stop=False) \ No newline at end of file diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 17dfa3a18..0f72d2bcd 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -185,4 +185,4 @@ def test_gemini_plugin_3d(early_stop: bool = True): if __name__ == "__main__": - test_gemini_plugin(early_stop=False) + test_gemini_plugin(early_stop=False) \ No newline at end of file diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 61cac1d83..daddf6dc7 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -186,4 +186,4 @@ def test_gemini_ckpIO_3d(): if __name__ == "__main__": - test_gemini_ckpIO() + test_gemini_ckpIO() \ No newline at end of file diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py index d0c4cd0a7..aeca5f21d 100644 --- a/tests/test_lazy/test_models.py +++ b/tests/test_lazy/test_models.py @@ -24,4 +24,4 @@ def test_torchvision_models_lazy_init(subset, default_device): if __name__ == "__main__": - test_torchvision_models_lazy_init("transformers", "cpu") + test_torchvision_models_lazy_init("transformers", "cpu") \ No newline at end of file diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index 10ffdcd71..e056860ed 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -1,3 +1,4 @@ +import os from contextlib import nullcontext import torch @@ -11,8 +12,10 @@ from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLin from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn - # This code is copied from https://github.com/huggingface/transformers +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + + class Conv1D(nn.Module): 
""" 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index 5bacf1865..defa4afb9 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -1,3 +1,4 @@ +import os from contextlib import nullcontext import torch @@ -11,6 +12,8 @@ from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row from colossalai.tensor.d_tensor import is_distributed_tensor from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): ctx = LazyInitContext() if lazy_init else nullcontext() diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py index b02d58181..5e996d2ba 100644 --- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -1,3 +1,4 @@ +import os from contextlib import nullcontext import torch @@ -11,8 +12,10 @@ from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLin from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn - # This code is copied from https://github.com/huggingface/transformers +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + + class Conv1D(nn.Module): """ 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).