[examples] adding tflops to PaLM (#2365)

pull/2424/head
ZijianYY 2023-01-10 16:18:56 +08:00 committed by GitHub
parent 93f62dd152
commit fe0f7970a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 41 additions and 8 deletions

View File

@ -1,9 +1,11 @@
import gzip
import random
from time import time
from functools import partial
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import tqdm
from packaging import version
from palm_pytorch import PaLM
@ -21,7 +23,8 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
# constants
NUM_BATCHES = int(1000)
NUM_BATCHES = int(100)
WARMUP_BATCHES = 1
GRADIENT_ACCUMULATE_EVERY = 1
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
@ -76,10 +79,18 @@ def cycle(loader):
def decode_token(token):
return str(chr(max(32, token)))
def get_tflops(model_numel, batch_size, seq_len, step_time):
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
def decode_tokens(tokens):
return "".join(list(map(decode_token, tokens)))
def get_model_size(model: nn.Module):
total_numel = 0
for module in model.modules():
for p in module.parameters(recurse=False):
total_numel += p.numel()
return total_numel
# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
@ -143,7 +154,6 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
split_param_row_tp1d(param, pg) # row slice
else:
param.set_dist_spec(ReplicaSpec())
param.visited = True
@ -152,6 +162,7 @@ if args.distplan not in ["colossalai", "pytorch"]:
raise TypeError(f"{args.distplan} is error")
disable_existing_loggers()
colossalai.launch_from_torch(config={})
logger = get_dist_logger()
with gzip.open("./data/enwik8.gz") as file:
X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
@ -188,7 +199,7 @@ if args.distplan == "colossalai":
ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
with ctx:
model = PaLM(num_tokens=256, dim=512, depth=8)
model = PaLM(num_tokens=50304, dim=4096, depth=64)
model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
pg = default_pg
@ -205,25 +216,42 @@ else:
model.cuda()
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# model is shared after TP
numel = get_model_size(model)
get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)
# training
model.train()
tflops_list = []
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
if args.distplan == "colossalai":
optimizer.zero_grad()
start = time()
loss = model(next(train_loader))
fwd_end = time()
fwd_time = fwd_end - start
# loss.backward()
optimizer.backward(loss)
bwd_end = time()
bwd_time = bwd_end - fwd_end
print(f"training loss: {loss.item()}")
# print(f"training loss: {loss.item()}")
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
# optim.step()
# optim.zero_grad()
optimizer.step()
optim_time = time() - bwd_end
step_time = time() - start
step_tflops = get_tflops_func(step_time)
logger.info(
f"[{i + 1}/{NUM_BATCHES}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
ranks=[0],
)
if i >= WARMUP_BATCHES:
tflops_list.append(step_tflops)
else:
for __ in range(GRADIENT_ACCUMULATE_EVERY):
loss = model(next(train_loader))
@ -234,6 +262,11 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
optim.step()
optim.zero_grad()
tflops_list.sort()
median_index = ((NUM_BATCHES - WARMUP_BATCHES) >> 1) + WARMUP_BATCHES
logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
# TODO
# if i % VALIDATE_EVERY == 0:
# model.eval()