import os
from contextlib import nullcontext
from functools import partial
from time import time

import psutil
import torch
import torch.nn as nn
from commons.model_zoo import model_builder
from commons.utils import get_data, get_profile_context, get_tflops, get_time_stamp
from packaging import version
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

CAI_VERSION = colossalai.__version__

# Distributed plans accepted by ``--distplan``; validated in ``main``.
VALID_DISTPLANS = ("CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO")


def parse_args():
    """Parse the command line on top of colossalai's default parser.

    Returns:
        The parsed namespace, carrying ``distplan``, ``batch_size``,
        ``model_type`` and ``train_step`` in addition to whatever the default
        parser defines.
    """
    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--distplan",
        type=str,
        default='CAI_Gemini',
        help=f"The distributed plan [{', '.join(VALID_DISTPLANS)}].",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="batch size per DP group of training.",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="gpt2_medium",
        help="model scale",
    )
    parser.add_argument(
        "--train_step",
        type=int,
        default=10,
        help="training iterations for test",
    )

    args = parser.parse_args()
    return args


class GPTLMLoss(nn.Module):
    """Next-token cross-entropy loss for a causal language model."""

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        """Return the cross-entropy between shifted logits and labels.

        The last position of ``logits`` and the first position of ``labels``
        are dropped so that position ``i`` is scored against token ``i + 1``.
        """
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


def get_cpu_mem():
    """Return the resident set size of the current process in MiB."""
    return psutil.Process().memory_info().rss / 1024**2


def get_gpu_mem():
    """Return ``torch.cuda.memory_allocated()`` in MiB."""
    return torch.cuda.memory_allocated() / 1024**2


def get_mem_info(prefix=''):
    """Format the current GPU and CPU memory usage, prepended with ``prefix``."""
    return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'


def get_model_size(model: nn.Module):
    """Return the total element count of the parameters owned by each submodule.

    Parameters are gathered per module with ``recurse=False``, so a parameter
    object registered on several modules is counted once per module.
    """
    total_numel = 0
    for module in model.modules():
        for p in module.parameters(recurse=False):
            total_numel += p.numel()
    return total_numel


def model_size_formatter(numel: int) -> str:
    """Render an element count with a decimal ``B``/``M``/``K`` suffix.

    Counts below 1000 are returned as a plain integer string.
    """
    GB_SIZE = 10**9
    MB_SIZE = 10**6
    KB_SIZE = 10**3
    if numel >= GB_SIZE:
        return f'{numel / GB_SIZE:.1f}B'
    elif numel >= MB_SIZE:
        return f'{numel / MB_SIZE:.1f}M'
    elif numel >= KB_SIZE:
        return f'{numel / KB_SIZE:.1f}K'
    else:
        return str(numel)


def set_cpu_maximum_parallelism():
    """Set ``OMP_NUM_THREADS`` to the hardware concurrency reported by torch.

    The value is scraped from the ``hardware_concurrency() : N`` line of
    ``torch.__config__.parallel_info()``.
    """
    conf_str = torch.__config__.parallel_info()
    inter_str = conf_str.split("hardware_concurrency() : ")[1]
    max_concurrency = inter_str.split('\n')[0]
    os.environ["OMP_NUM_THREADS"] = max_concurrency
    print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")


def main():
    """Run the GPT training benchmark and log the median per-step TFLOPS.

    Raises:
        TypeError: if ``--distplan`` is not one of ``VALID_DISTPLANS``.
    """
    # version check
    # this example is supposed to work for versions greater than 0.2.0
    assert version.parse(CAI_VERSION) >= version.parse("0.2.0")

    set_cpu_maximum_parallelism()
    args = parse_args()

    if args.distplan not in VALID_DISTPLANS:
        raise TypeError(f"{args.distplan} is error")

    # batch size per DP degree
    BATCH_SIZE = args.batch_size
    SEQ_LEN = 1024
    VOCAB_SIZE = 50257

    NUM_STEPS = args.train_step

    WARMUP_STEPS = 1
    assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
    assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
    PROF_FLAG = False  # The flag of profiling, False by default

    disable_existing_loggers()
    colossalai.launch_from_torch(config={})

    logger = get_dist_logger()
    logger.info(f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0])

    # build criterion
    criterion = GPTLMLoss()
    torch.manual_seed(123)
    if args.distplan.startswith("CAI"):
        # Only the Gemini plan builds the model under a lazy-init context.
        ctx = LazyInitContext(default_device=get_current_device()) if args.distplan == "CAI_Gemini" else nullcontext()
        # build GPT model
        with ctx:
            model = model_builder(args.model_type)(checkpoint=True)

        # assign running configurations
        if args.distplan == "CAI_ZeRO1":
            zero_stage = 1
        elif args.distplan == "CAI_ZeRO2":
            zero_stage = 2
        elif args.distplan == "CAI_Gemini":
            zero_stage = 3
        else:
            raise RuntimeError

        plugin = None
        if args.distplan.startswith("CAI_ZeRO"):
            plugin = LowLevelZeroPlugin(
                stage=zero_stage,
                reduce_bucket_size_in_m=12,
                overlap_communication=True,
                verbose=True,
            )
        elif args.distplan == "CAI_Gemini":
            plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd)
        else:
            raise RuntimeError

        # build a highly optimized gpu/cpu optimizer
        optimizer = HybridAdam(model.parameters(), lr=1e-3)

        logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
    elif args.distplan.startswith("Pytorch"):
        # parse_args() above registers no --tp_degree flag, so read it
        # defensively instead of failing with AttributeError.
        # NOTE(review): assumes the default parser does not add it either — confirm.
        assert getattr(args, "tp_degree", 1) == 1, "The degree of TP should be 1 for DDP examples."
        model = model_builder(args.model_type)(checkpoint=True).cuda()
        plugin = TorchDDPPlugin()
        if args.distplan.endswith("DDP"):
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        elif args.distplan.endswith("ZeRO"):
            from torch.distributed.optim import ZeroRedundancyOptimizer

            optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
    else:
        raise RuntimeError

    # wrap your model and optimizer
    booster = Booster(plugin=plugin)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    # model is shared after TP
    numel = get_model_size(model)
    logger.info(f"the size of testing model size is {model_size_formatter(numel)}.")
    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])

    # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
    # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
    # = batch_per_DP_group * numel * seq_len * 8
    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)

    torch.cuda.synchronize()
    model.train()
    tflops_list = []

    def train_step(n):
        """Run forward, backward and optimizer step for iteration ``n`` (0-based).

        Logs memory usage after each phase and the phase timings, and records
        the step's TFLOPS once the warmup iterations are over.
        """
        # we just use randomly generated data here
        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()

        start = time()
        outputs = model(input_ids, attn_mask)
        loss = criterion(outputs, input_ids)
        # Synchronize before reading the clock so queued GPU work is included.
        torch.cuda.synchronize()
        fwd_end = time()
        fwd_time = fwd_end - start
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])

        booster.backward(loss, optimizer)

        torch.cuda.synchronize()
        bwd_end = time()
        bwd_time = bwd_end - fwd_end
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])

        optimizer.step()
        torch.cuda.synchronize()
        optim_end = time()
        optim_time = optim_end - bwd_end
        step_time = optim_end - start
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0])

        step_tflops = get_tflops_func(step_time)
        logger.info(
            f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {step_tflops:.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
            ranks=[0],
        )
        # Warmup iterations are excluded from the reported statistics.
        if n >= WARMUP_STEPS:
            tflops_list.append(step_tflops)

    demo_profiler = get_profile_context(
        PROF_FLAG,
        WARMUP_STEPS,
        NUM_STEPS - WARMUP_STEPS,
        save_dir=f"profile/{get_time_stamp()}-demo",
    )

    with demo_profiler as prof:
        for n in range(NUM_STEPS):
            train_step(n)
            prof.step()

    tflops_list.sort()
    # tflops_list only holds the post-warmup samples, so the median is simply
    # its middle element (the count is asserted odd above). Do NOT offset the
    # index by WARMUP_STEPS: that would skew the result or run off the end.
    median_index = len(tflops_list) // 2
    logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
    torch.cuda.synchronize()


if __name__ == '__main__':
    main()