diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py
index 7c080b7f3..6725c07df 100644
--- a/examples/language/palm/train.py
+++ b/examples/language/palm/train.py
@@ -1,9 +1,11 @@
 import gzip
 import random
-
+from time import time
+from functools import partial
 import numpy as np
 import torch
 import torch.optim as optim
+import torch.nn as nn
 import tqdm
 from packaging import version
 from palm_pytorch import PaLM
@@ -21,7 +23,8 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
 
 # constants
 
-NUM_BATCHES = int(1000)
+NUM_BATCHES = int(100)
+WARMUP_BATCHES = 1
 GRADIENT_ACCUMULATE_EVERY = 1
 LEARNING_RATE = 2e-4
 VALIDATE_EVERY = 100
@@ -76,10 +79,18 @@ def cycle(loader):
 def decode_token(token):
     return str(chr(max(32, token)))
 
+def get_tflops(model_numel, batch_size, seq_len, step_time):
+    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
 
 def decode_tokens(tokens):
     return "".join(list(map(decode_token, tokens)))
 
+def get_model_size(model: nn.Module):
+    total_numel = 0
+    for module in model.modules():
+        for p in module.parameters(recurse=False):
+            total_numel += p.numel()
+    return total_numel
 
 # Gemini + ZeRO DDP
 def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
@@ -143,7 +154,6 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
                 split_param_row_tp1d(param, pg)    # row slice
             else:
                 param.set_dist_spec(ReplicaSpec())
-            param.visited = True
 
@@ -152,6 +162,7 @@ if args.distplan not in ["colossalai", "pytorch"]:
     raise TypeError(f"{args.distplan} is error")
 disable_existing_loggers()
 colossalai.launch_from_torch(config={})
+logger = get_dist_logger()
 
 with gzip.open("./data/enwik8.gz") as file:
     X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
@@ -188,7 +199,7 @@ if args.distplan == "colossalai":
     ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
 
     with ctx:
-        model = PaLM(num_tokens=256, dim=512, depth=8)
+        model = PaLM(num_tokens=50304, dim=4096, depth=64)
         model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
 
     pg = default_pg
@@ -205,25 +216,42 @@ else:
     model.cuda()
     optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
-
+# model is sharded after TP
+numel = get_model_size(model)
+get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)
 
 # training
 model.train()
-
+tflops_list = []
 for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
     if args.distplan == "colossalai":
         optimizer.zero_grad()
-
+        start = time()
         loss = model(next(train_loader))
+        fwd_end = time()
+        fwd_time = fwd_end - start
         # loss.backward()
         optimizer.backward(loss)
+        bwd_end = time()
+        bwd_time = bwd_end - fwd_end
 
-        print(f"training loss: {loss.item()}")
+        # print(f"training loss: {loss.item()}")
         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
         # optim.step()
         # optim.zero_grad()
         optimizer.step()
+        optim_time = time() - bwd_end
+        step_time = time() - start
+
+        step_tflops = get_tflops_func(step_time)
+        logger.info(
+            f"[{i + 1}/{NUM_BATCHES}] Loss: {loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {step_tflops:.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
+            ranks=[0],
+        )
+        if i >= WARMUP_BATCHES:
+            tflops_list.append(step_tflops)
+
     else:
         for __ in range(GRADIENT_ACCUMULATE_EVERY):
            loss = model(next(train_loader))
@@ -233,6 +261,11 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
     optim.step()
     optim.zero_grad()
+
+tflops_list.sort()
+median_index = (NUM_BATCHES - WARMUP_BATCHES) >> 1
+logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
+
 # TODO
 # if i % VALIDATE_EVERY == 0:
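
A note on the throughput figure: `get_tflops` uses the common dense-transformer approximation of 8 x parameters x tokens FLOPs per step, i.e. 2 FLOPs per multiply-accumulate over roughly four model traversals (one forward, two for backward, plus one forward recompute when activation checkpointing is enabled; without checkpointing the usual factor is 6). A minimal sketch with made-up numbers, just to make the units concrete; every value below is hypothetical:

    # Sanity-check of the get_tflops formula; all inputs are invented.
    def get_tflops(model_numel, batch_size, seq_len, step_time):
        return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)

    numel = 1.3e10          # assumed ~13B-parameter model
    batch, seq = 8, 1024    # 8192 tokens per step
    step_time = 10.0        # seconds per step

    flops = 8 * numel * batch * seq                  # total FLOPs for one step
    print(flops / 1e12 / step_time)                  # ~85.197 TFLOPS, by hand
    print(get_tflops(numel, batch, seq, step_time))  # same value from the helper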
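`get_model_size` visits each submodule with `recurse=False`, so every parameter is counted exactly once; as long as no parameter is shared between modules, this agrees with summing over `model.parameters()` directly. A quick check on a toy module (the layer sizes are arbitrary):

    import torch.nn as nn

    def get_model_size(model: nn.Module):
        total_numel = 0
        for module in model.modules():
            for p in module.parameters(recurse=False):
                total_numel += p.numel()
        return total_numel

    toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    # Linear(4, 8): 4*8 weights + 8 biases = 40; Linear(8, 2): 16 + 2 = 18
    assert get_model_size(toy) == 58
    assert get_model_size(toy) == sum(p.numel() for p in toy.parameters())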
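Discarding the first `WARMUP_BATCHES` measurements keeps one-off start-up costs (CUDA context and allocator warm-up, first-step kernel selection) out of the summary, and the median is robust to occasional stragglers in a way a mean is not. With `NUM_BATCHES = 100` and `WARMUP_BATCHES = 1` the list holds 99 samples, so indexing the sorted list at `len >> 1` yields the exact median. An illustration with invented per-step TFLOPS readings:

    import statistics

    WARMUP_BATCHES = 1
    # Hypothetical readings: slow first step, one straggler, steady rest.
    samples = [12.0, 85.1, 84.9, 85.3, 40.0, 85.0, 85.2]
    tflops_list = samples[WARMUP_BATCHES:]      # drop warm-up measurements

    tflops_list.sort()
    mid = tflops_list[len(tflops_list) >> 1]    # middle of the sorted samples
    print(mid)                                  # 85.1
    print(statistics.median(tflops_list))       # 85.05 (even-length average)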