mirror of https://github.com/hpcaitech/ColossalAI
258 lines
8.5 KiB
Python
258 lines
8.5 KiB
Python
import argparse
|
||
import os
|
||
from contextlib import nullcontext
|
||
from functools import partial
|
||
from time import time
|
||
|
||
import psutil
|
||
import torch
|
||
import torch.nn as nn
|
||
from commons.model_zoo import model_builder
|
||
from commons.utils import get_data, get_profile_context, get_tflops, get_time_stamp
|
||
from packaging import version
|
||
|
||
import colossalai
|
||
from colossalai.accelerator import get_accelerator
|
||
from colossalai.booster import Booster
|
||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||
from colossalai.lazy import LazyInitContext
|
||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||
from colossalai.nn.optimizer import HybridAdam
|
||
|
||
# Version string of the installed ColossalAI package; main() compares it against
# the minimum release this example supports.
CAI_VERSION: str = colossalai.__version__
|
||
|
||
|
||
def parse_args(argv=None):
    """Parse the command-line options of the GPT training benchmark.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with ``distplan``, ``batch_size``, ``model_type``,
        ``train_step`` and ``tp_degree``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--distplan",
        type=str,
        default="CAI_Gemini",
        # Keep this list in sync with the plan names main() accepts.
        help="The distributed plan [CAI_ZeRO1, CAI_ZeRO2, CAI_Gemini, Pytorch_DDP, Pytorch_ZeRO].",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="batch size per DP group of training.",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="gpt2_medium",
        help="model scale",
    )
    parser.add_argument(
        "--train_step",
        type=int,
        default=10,
        help="training iterations for test",
    )
    # main() checks ``args.tp_degree`` on the Pytorch_* plans; without this option
    # those plans fail with AttributeError. Only 1 is accepted there.
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=1,
        help="tensor parallel degree (the Pytorch_* plans require 1).",
    )
    return parser.parse_args(argv)
|
||
|
||
|
||
class GPTLMLoss(nn.Module):
    """Next-token cross-entropy loss for causal language models.

    The prediction at sequence position ``t`` is scored against the label at
    position ``t + 1``: the last prediction and the first label are dropped,
    then both are flattened for ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        """Return the mean shifted cross-entropy.

        Args:
            logits: tensor whose last dim is the class (vocabulary) dim and whose
                second-to-last dim is the sequence dim.
            labels: integer tensor with the same leading dims as ``logits``
                minus the class dim.
        """
        num_classes = logits.size(-1)
        # Position t predicts token t+1: drop the final prediction and the first target.
        predictions = logits[..., :-1, :].reshape(-1, num_classes)
        targets = labels[..., 1:].reshape(-1)
        return self.loss_fn(predictions, targets)
|
||
|
||
|
||
def get_cpu_mem():
    """Return the resident set size (RSS) of the current process in MiB."""
    rss_bytes = psutil.Process().memory_info().rss  # psutil reports bytes
    return rss_bytes / (1024 * 1024)  # bytes -> MiB
|
||
|
||
|
||
def get_gpu_mem():
    """Return the memory currently allocated by tensors on the current CUDA device, in MiB."""
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes / (1024 * 1024)  # bytes -> MiB
|
||
|
||
|
||
def get_mem_info(prefix=""):
    """Return a one-line GPU/CPU memory report in MB, prefixed with ``prefix``."""
    gpu_mb = get_gpu_mem()
    cpu_mb = get_cpu_mem()
    return f"{prefix}GPU memory usage: {gpu_mb:.2f} MB, CPU memory usage: {cpu_mb:.2f} MB"
|
||
|
||
|
||
def get_model_size(model: nn.Module):
    """Return the number of parameter elements in ``model``.

    Every submodule contributes only the parameters it registers directly
    (``recurse=False``). A parameter registered on several submodules, such as
    tied weights, is therefore counted once per owning submodule.
    """
    return sum(
        param.numel()
        for submodule in model.modules()
        for param in submodule.parameters(recurse=False)
    )
|
||
|
||
|
||
def model_size_formatter(numel: int) -> str:
    """Format an element count with a decimal B/M/K suffix and one decimal place.

    Counts below one thousand are returned as a plain integer string.
    """
    # Checked from largest to smallest so the first match wins.
    for unit, suffix in ((10**9, "B"), (10**6, "M"), (10**3, "K")):
        if numel >= unit:
            return f"{numel / unit:.1f}{suffix}"
    return str(numel)
|
||
|
||
|
||
def set_cpu_maximum_parallelism():
    """Set the ``OMP_NUM_THREADS`` environment variable to the hardware concurrency.

    The value is read from the ``hardware_concurrency() : N`` entry of
    ``torch.__config__.parallel_info()``. That report is free-form text, so if
    the entry is missing or not an integer, ``os.cpu_count()`` (or 1) is used
    instead of failing with an ``IndexError``.
    """
    marker = "hardware_concurrency() : "
    _, found, tail = torch.__config__.parallel_info().partition(marker)
    # Keep only the rest of the marker's line; strip stray whitespace / "\r".
    max_concurrency = tail.split("\n")[0].strip() if found else ""
    if not max_concurrency.isdigit():
        max_concurrency = str(os.cpu_count() or 1)
    os.environ["OMP_NUM_THREADS"] = max_concurrency
    print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")
|
||
|
||
|
||
def main():
    """Benchmark GPT training throughput under the distributed plan chosen on the CLI.

    The function proceeds in these steps:

    1. Check the ColossalAI version and set ``OMP_NUM_THREADS``.
    2. Parse the CLI and launch the distributed environment.
    3. Build the model, optimizer and plugin for ``--distplan``, then boost them
       with a ``Booster``.
    4. Run ``--train_step`` iterations on randomly generated data inside the
       profiling context.
    5. Log per-step memory, timings and TFLOPS, plus the median TFLOPS over the
       post-warm-up steps.

    Raises:
        TypeError: if ``--distplan`` is not one of the supported plan names.
        RuntimeError: on an unexpected plan name inside the plan-specific branches.
    """
    # version check
    # this example is supposed to work for versions greater than 0.2.0
    assert version.parse(CAI_VERSION) >= version.parse("0.2.0")

    set_cpu_maximum_parallelism()
    args = parse_args()

    if args.distplan not in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]:
        raise TypeError(f"{args.distplan} is error")

    # batch size per DP degree
    BATCH_SIZE = args.batch_size
    SEQ_LEN = 1024
    VOCAB_SIZE = 50257

    NUM_STEPS = args.train_step

    WARMUP_STEPS = 1
    assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
    assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
    # Profiling switch handed to get_profile_context below; enabled in this script.
    PROF_FLAG = True

    disable_existing_loggers()
    colossalai.launch_from_torch()

    logger = get_dist_logger()
    logger.info(f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0])

    # build criterion
    criterion = GPTLMLoss()
    torch.manual_seed(123)
    if args.distplan.startswith("CAI"):
        # Gemini builds the model under LazyInitContext on the accelerator device;
        # the ZeRO plans build it eagerly.
        ctx = (
            LazyInitContext(default_device=get_accelerator().get_current_device())
            if args.distplan == "CAI_Gemini"
            else nullcontext()
        )
        # build GPT model
        with ctx:
            model = model_builder(args.model_type)(checkpoint=True)

        # assign running configurations
        if args.distplan == "CAI_ZeRO1":
            zero_stage = 1
        elif args.distplan == "CAI_ZeRO2":
            zero_stage = 2
        elif args.distplan == "CAI_Gemini":
            zero_stage = 3
        else:
            raise RuntimeError

        plugin = None
        if args.distplan.startswith("CAI_ZeRO"):
            plugin = LowLevelZeroPlugin(
                stage=zero_stage, reduce_bucket_size_in_m=12, overlap_communication=True, verbose=True
            )
        elif args.distplan == "CAI_Gemini":
            plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd, max_prefetch=1)
        else:
            raise RuntimeError

        # build a highly optimized gpu/cpu optimizer
        optimizer = HybridAdam(model.parameters(), lr=1e-3)

        logger.info(get_mem_info(prefix="After init optim, "), ranks=[0])
    elif args.distplan.startswith("Pytorch"):
        # The CLI may not define --tp_degree; a missing option means no TP (degree 1).
        assert getattr(args, "tp_degree", 1) == 1, "The degree of TP should be 1 for DDP examples."
        model = model_builder(args.model_type)(checkpoint=True).cuda()
        plugin = TorchDDPPlugin()
        if args.distplan.endswith("DDP"):
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        elif args.distplan.endswith("ZeRO"):
            from torch.distributed.optim import ZeroRedundancyOptimizer

            optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
    else:
        raise RuntimeError

    # wrap your model and optimizer
    booster = Booster(plugin=plugin)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    # model is sharded after TP
    numel = get_model_size(model)
    logger.info(f"the size of testing model size is {model_size_formatter(numel)}.")
    logger.info(get_mem_info(prefix="After init model, "), ranks=[0])

    # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
    # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
    # = batch_per_DP_group * numel * seq_len * 8
    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)

    torch.cuda.synchronize()
    model.train()
    tflops_list = []

    def train_step(n):
        """Run, time and log one training iteration; ``n`` is the 0-based step index."""
        # we just use randomly generated data here
        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()

        start = time()
        outputs = model(input_ids, attn_mask)
        loss = criterion(outputs, input_ids)
        # Synchronize so the wall-clock timings cover the queued device work.
        torch.cuda.synchronize()
        fwd_end = time()
        fwd_time = fwd_end - start
        logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Forward "), ranks=[0])
        booster.backward(loss, optimizer)

        torch.cuda.synchronize()
        bwd_end = time()
        bwd_time = bwd_end - fwd_end
        logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Backward "), ranks=[0])

        optimizer.step()
        torch.cuda.synchronize()
        optim_time = time() - bwd_end
        step_time = time() - start
        logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Optimizer step "), ranks=[0])

        # Compute once and reuse for both the log line and the median statistics.
        step_tflops = get_tflops_func(step_time)
        logger.info(
            f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {step_tflops:.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
            ranks=[0],
        )
        # Only post-warm-up steps contribute to the median.
        if n >= WARMUP_STEPS:
            tflops_list.append(step_tflops)

    demo_profiler = get_profile_context(
        PROF_FLAG, WARMUP_STEPS, NUM_STEPS - WARMUP_STEPS, save_dir=f"profile/{get_time_stamp()}-demo"
    )

    with demo_profiler as prof:
        for n in range(NUM_STEPS):
            train_step(n)
            prof.step()

    tflops_list.sort()
    # tflops_list already excludes the warm-up steps, so the median is simply its
    # own middle element; no WARMUP_STEPS offset belongs in this index.
    median_index = len(tflops_list) // 2
    logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
    torch.cuda.synchronize()
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the benchmark only when this file is executed as a script, not on import.
    main()
|