[examples] adding tflops to PaLM (#2365)

pull/2424/head
Authored by ZijianYY on 2023-01-10 16:18:56 +08:00, committed by GitHub
parent 93f62dd152
commit fe0f7970a2
1 changed file with 41 additions and 8 deletions


@@ -1,9 +1,11 @@
 import gzip
 import random
+from time import time
+from functools import partial
 import numpy as np
 import torch
 import torch.optim as optim
+import torch.nn as nn
 import tqdm
 from packaging import version
 from palm_pytorch import PaLM
@@ -21,7 +23,8 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
 # constants
-NUM_BATCHES = int(1000)
+NUM_BATCHES = int(100)
+WARMUP_BATCHES = 1
 GRADIENT_ACCUMULATE_EVERY = 1
 LEARNING_RATE = 2e-4
 VALIDATE_EVERY = 100
@@ -76,10 +79,18 @@ def cycle(loader):
 def decode_token(token):
     return str(chr(max(32, token)))
+def get_tflops(model_numel, batch_size, seq_len, step_time):
+    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
 def decode_tokens(tokens):
     return "".join(list(map(decode_token, tokens)))
+def get_model_size(model: nn.Module):
+    total_numel = 0
+    for module in model.modules():
+        for p in module.parameters(recurse=False):
+            total_numel += p.numel()
+    return total_numel
 # Gemini + ZeRO DDP
 def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
@@ -143,7 +154,6 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
                split_param_row_tp1d(param, pg)  # row slice
            else:
                param.set_dist_spec(ReplicaSpec())
            param.visited = True
@@ -152,6 +162,7 @@ if args.distplan not in ["colossalai", "pytorch"]:
     raise TypeError(f"{args.distplan} is error")
 disable_existing_loggers()
 colossalai.launch_from_torch(config={})
+logger = get_dist_logger()
 with gzip.open("./data/enwik8.gz") as file:
     X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
@@ -188,7 +199,7 @@ if args.distplan == "colossalai":
     ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
     with ctx:
-        model = PaLM(num_tokens=256, dim=512, depth=8)
+        model = PaLM(num_tokens=50304, dim=4096, depth=64)
         model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
     pg = default_pg
@@ -205,25 +216,42 @@ else:
     model.cuda()
     optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+# model is shared after TP
+numel = get_model_size(model)
+get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)
 # training
 model.train()
+tflops_list = []
 for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
     if args.distplan == "colossalai":
         optimizer.zero_grad()
+        start = time()
         loss = model(next(train_loader))
+        fwd_end = time()
+        fwd_time = fwd_end - start
         # loss.backward()
         optimizer.backward(loss)
+        bwd_end = time()
+        bwd_time = bwd_end - fwd_end
-        print(f"training loss: {loss.item()}")
+        # print(f"training loss: {loss.item()}")
         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
         # optim.step()
         # optim.zero_grad()
         optimizer.step()
+        optim_time = time() - bwd_end
+        step_time = time() - start
+        step_tflops = get_tflops_func(step_time)
+        logger.info(
+            f"[{i + 1}/{NUM_BATCHES}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
+            ranks=[0],
+        )
+        if i >= WARMUP_BATCHES:
+            tflops_list.append(step_tflops)
     else:
         for __ in range(GRADIENT_ACCUMULATE_EVERY):
             loss = model(next(train_loader))
@@ -234,6 +262,11 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
         optim.step()
         optim.zero_grad()
+tflops_list.sort()
+median_index = ((NUM_BATCHES - WARMUP_BATCHES) >> 1) + WARMUP_BATCHES
+logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
 # TODO
 # if i % VALIDATE_EVERY == 0:
 #     model.eval()
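A note on the new helpers (not part of the commit itself): get_tflops converts a measured step time into an estimated training throughput. The constant 8 is presumably the usual rough accounting of about 2 FLOPs per parameter per token for the forward pass, 4 for the backward pass, and another 2 for activation recomputation; FLOPs that do not scale with the parameter count (e.g. attention score matmuls) are ignored. A minimal, self-contained sketch of how the pieces fit together, using a stand-in model and stand-in batch/sequence sizes:

# Sketch only: mirrors the helpers added in the diff, with stand-in values.
from functools import partial
from time import time

import torch
import torch.nn as nn

def get_model_size(model: nn.Module) -> int:
    # Sum each module's own parameters (recurse=False) over all modules.
    return sum(p.numel() for m in model.modules() for p in m.parameters(recurse=False))

def get_tflops(model_numel, batch_size, seq_len, step_time):
    # Assumed accounting: ~2 FLOPs/param/token forward + ~4 backward + ~2 recompute = 8.
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)

model = nn.Linear(1024, 1024)      # stand-in for the PaLM model
batch_size, seq_len = 8, 1024      # stand-ins for args.batch_size and SEQ_LEN
tflops_fn = partial(get_tflops, get_model_size(model), batch_size, seq_len)

start = time()
model(torch.randn(batch_size, seq_len, 1024))   # one "training step"
print(f"~{tflops_fn(time() - start):.3f} TFLOPS")

In the training loop itself, the function is bound once via partial(get_tflops, numel, args.batch_size, SEQ_LEN); each step's wall-clock time yields one TFLOPS sample, samples after the first WARMUP_BATCHES iterations are kept, and the median of the sorted list is logged at the end.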