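# Benchmark / example script: trains a PaLM language model (palm_pytorch) on enwik8 with
# the ColossalAI booster API, comparing a plain PyTorch run against ColossalAI plugins.
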
import argparse
import gzip
from contextlib import nullcontext
from functools import partial
from time import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from torch.utils.data import DataLoader, Dataset

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import HybridAdam
from colossalai.utils import get_current_device

# constants

NUM_BATCHES = 10
WARMUP_BATCHES = 1
GRADIENT_ACCUMULATE_EVERY = 1
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SEQ_LEN = 1024
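# NUM_BATCHES is deliberately small: this script is a throughput benchmark rather than a
# full training run, and the first WARMUP_BATCHES steps are excluded from the TFLOPS
# statistics reported at the end.
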
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--distplan",
        type=str,
        default="colossalai",
        help="The distributed plan [colossalai, pytorch].",
    )
    parser.add_argument(
        "--offload_optim_frac",
        type=float,
        default=1.0,
        help="Fraction of optimizer states to be offloaded. This is only used for gemini.",
    )
    parser.add_argument(
        "-p",
        "--plugin",
        type=str,
        default="torch_ddp",
        choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"],
        help="plugin to use",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="batch size per DP group of training.",
    )
    parser.add_argument(
        "--dummy_data",
        # store_true avoids the argparse pitfall where type=bool treats any
        # non-empty string (including "False") as True
        action="store_true",
        help="use dummy dataset.",
    )
    args = parser.parse_args()
    return args


# helpers
def cycle(loader):
    while True:
        for data in loader:
            yield data


def decode_token(token):
    return str(chr(max(32, token)))


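# Rough FLOP model behind get_tflops: about 2 * numel FLOPs per token for the forward
# pass and ~4 * numel for the backward pass; the factor of 8 (rather than 6) appears to
# budget an extra forward pass for activation recomputation. Attention FLOPs are
# ignored, so treat the reported TFLOPS as an approximation.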
def get_tflops(model_numel, batch_size, seq_len, step_time):
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)


def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))


def get_model_size(model: nn.Module):
    total_numel = 0
    for module in model.modules():
        for p in module.parameters(recurse=False):
            total_numel += p.numel()
    return total_numel


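# The script is meant to be launched with torchrun: launch_from_torch below reads the
# rank, world size, and master address from the environment variables that torchrun
# sets. A typical invocation (the script name is illustrative):
#   torchrun --nproc_per_node 8 train.py --distplan colossalai --plugin gemini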
args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
    raise ValueError(f"Unsupported distplan: {args.distplan}")
disable_existing_loggers()
colossalai.launch_from_torch(config={})
logger = get_dist_logger()


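# enwik8 is treated as a stream of raw bytes: the first 90M bytes become the training
# split and the next 5M the validation split, so token ids stay in [0, 255]. The gzipped
# file is expected at ./data/enwik8.gz; --dummy_data replaces it with random tokens.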
def generate_dataset(dummy_data: bool = False):
    if not dummy_data:
        with gzip.open("./data/enwik8.gz") as file:
            # np.frombuffer returns a read-only view of the bytes, so copy it before
            # handing it to torch.from_numpy
            X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
            trX, vaX = np.split(X, [int(90e6)])
            data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
            # print(f"data_train {data_train.shape} {data_train.dtype} {max(data_train)} {min(data_train)}")
            # print(f"data_val {data_val.shape} {data_val.dtype} {max(data_val)} {min(data_val)}")
            return data_train, data_val
    else:
        return torch.randint(0, 100, (90000000,)), torch.randint(0, 100, (5000000,))


data_train, data_val = generate_dataset(args.dummy_data)

print("generate dataset ready!")


class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len


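# Each sample is SEQ_LEN + 1 tokens long; the AutoregressiveWrapper is expected to split
# it into inputs (tokens[:-1]) and shifted next-token targets (tokens[1:]).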
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=args.batch_size))
val_loader = cycle(DataLoader(val_dataset, batch_size=args.batch_size))

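# Plugin choices, roughly (see the ColossalAI booster plugin docs for details):
#   torch_ddp / torch_ddp_fp16 - plain PyTorch DDP, optionally with fp16 mixed precision
#   gemini                     - ZeRO-style sharding with chunk-based, heterogeneous
#                                (GPU + CPU) memory management; --offload_optim_frac
#                                controls how much optimizer state lives on the CPU
#   low_level_zero             - ZeRO stage 1/2 style sharding of optimizer states and gradients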
if args.distplan == "colossalai":
    # configure the booster plugin
    booster_kwargs = {}
    if args.plugin == "torch_ddp_fp16":
        booster_kwargs["mixed_precision"] = "fp16"
    if args.plugin.startswith("torch_ddp"):
        plugin = TorchDDPPlugin()
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5)
    elif args.plugin == "low_level_zero":
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    logger.info(f"plugin: {plugin}")
    booster = Booster(plugin=plugin, **booster_kwargs)

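    # With the gemini plugin, LazyInitContext defers parameter materialization so the
    # full-size PaLM weights do not have to be allocated on a single device before
    # Gemini shards them; the other plugins build the model eagerly via nullcontext.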
    ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == "gemini" else nullcontext()

    # instantiate GPT-like decoder model
    with ctx:
        model = PaLM(num_tokens=50304, dim=4096, depth=64)
        model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)

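    # HybridAdam ships both CPU and GPU Adam kernels, so optimizer states that Gemini
    # offloads to host memory can still be updated there. booster.boost wraps the model
    # and optimizer for the chosen plugin; the three discarded return values correspond
    # to the (unused) criterion, dataloader, and lr_scheduler slots.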
    # optimizer
    optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)
    model, optimizer, _, _, _ = booster.boost(model, optimizer)

else:
    # plain PyTorch baseline; note that it builds a much smaller PaLM than the
    # colossalai branch above
    model = PaLM(num_tokens=256, dim=512, depth=8)
    model = AutoregressiveWrapper(model, max_seq_len=2048)
    model.cuda()
    optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# model is sharded after TP
numel = get_model_size(model)
get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)

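# Benchmark loop. On the colossalai path the backward pass goes through the
# booster-wrapped optimizer (optimizer.backward(loss)) so the plugin can apply loss
# scaling and its own gradient handling; step, forward, backward, and optimizer times
# are logged per iteration, and the first WARMUP_BATCHES steps are excluded from the
# TFLOPS statistics.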
# training
model.train()
tflops_list = []
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    if args.distplan == "colossalai":
        optimizer.zero_grad()
        start = time()
        loss = model(next(train_loader))
        fwd_end = time()
        fwd_time = fwd_end - start
        # loss.backward()
        optimizer.backward(loss)
        bwd_end = time()
        bwd_time = bwd_end - fwd_end

        # print(f"training loss: {loss.item()}")
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        # optim.step()
        # optim.zero_grad()
        optimizer.step()
        optim_time = time() - bwd_end
        step_time = time() - start

        step_tflops = get_tflops_func(step_time)
        logger.info(
            f"[{i + 1}/{NUM_BATCHES}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
            ranks=[0],
        )
        if i >= WARMUP_BATCHES:
            tflops_list.append(step_tflops)

    else:
        for __ in range(GRADIENT_ACCUMULATE_EVERY):
            loss = model(next(train_loader))
            loss.backward()

        print(f"training loss: {loss.item()}")
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optim.step()
        optim.zero_grad()

if tflops_list:
    tflops_list.sort()
    # tflops_list only holds the post-warmup samples, so the median sits in the middle
    # of its (NUM_BATCHES - WARMUP_BATCHES) entries
    median_index = (NUM_BATCHES - WARMUP_BATCHES) >> 1
    logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")

# TODO
# if i % VALIDATE_EVERY == 0:
#     model.eval()
#     with torch.no_grad():
#         loss = model(next(val_loader))
#         print(f"validation loss: {loss.item()}")

# if i % GENERATE_EVERY == 0:
#     model.eval()
#     inp = random.choice(val_dataset)[:-1]
#     prime = decode_tokens(inp)
#     print(f"%s \n\n %s", (prime, "*" * 100))

#     sample = model.generate(inp[None, ...], GENERATE_LENGTH)
#     output_str = decode_tokens(sample[0])
#     print(output_str)