mirror of https://github.com/hpcaitech/ColossalAI
[examples] adding tflops to PaLM (#2365)
parent 93f62dd152
commit fe0f7970a2
@@ -1,9 +1,11 @@
 import gzip
 import random
+from time import time
+from functools import partial
 import numpy as np
 import torch
 import torch.optim as optim
+import torch.nn as nn
 import tqdm
 from packaging import version
 from palm_pytorch import PaLM
@@ -21,7 +23,8 @@ from colossalai.utils.model.colo_init_context import ColoInitContext

 # constants

-NUM_BATCHES = int(1000)
+NUM_BATCHES = int(100)
+WARMUP_BATCHES = 1
 GRADIENT_ACCUMULATE_EVERY = 1
 LEARNING_RATE = 2e-4
 VALIDATE_EVERY = 100
@@ -76,10 +79,18 @@ def cycle(loader):
 def decode_token(token):
     return str(chr(max(32, token)))

+def get_tflops(model_numel, batch_size, seq_len, step_time):
+    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)

 def decode_tokens(tokens):
     return "".join(list(map(decode_token, tokens)))

+def get_model_size(model: nn.Module):
+    total_numel = 0
+    for module in model.modules():
+        for p in module.parameters(recurse=False):
+            total_numel += p.numel()
+    return total_numel

 # Gemini + ZeRO DDP
 def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
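The factor of 8 in get_tflops matches the common transformer estimate of roughly 2 * numel FLOPs per token for the forward pass, 4 * numel for the backward pass, and another 2 * numel for the forward recomputation done under activation checkpointing; without recomputation the usual constant is 6. The 1e-12 in the denominator only guards against division by zero. A worked example with illustrative numbers not taken from the commit (a 1e9-parameter model, batch size 8, sequence length 1024, a 2-second step):

    numel, batch_size, seq_len, step_time = 1e9, 8, 1024, 2.0
    tflops = numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
    print(f"{tflops:.1f} TFLOPS")    # ~32.8

get_model_size walks modules with recurse=False rather than calling model.parameters() once; the per-module walk is explicit but counts a parameter shared between two modules once per owner, whereas model.parameters() deduplicates.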
@@ -143,7 +154,6 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
             split_param_row_tp1d(param, pg)    # row slice
         else:
             param.set_dist_spec(ReplicaSpec())

         param.visited = True
-
@@ -152,6 +162,7 @@ if args.distplan not in ["colossalai", "pytorch"]:
     raise TypeError(f"{args.distplan} is error")
 disable_existing_loggers()
 colossalai.launch_from_torch(config={})
+logger = get_dist_logger()

 with gzip.open("./data/enwik8.gz") as file:
     X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
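The np.fromstring call on this unchanged context line is deprecated for binary input in modern NumPy and emits a DeprecationWarning at runtime. A drop-in replacement, assuming a reasonably recent NumPy:

    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()

The .copy() matters because np.frombuffer returns a read-only view over the bytes object, which would trigger a further warning if the array is later wrapped with torch.from_numpy.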
@@ -188,7 +199,7 @@ if args.distplan == "colossalai":
     ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)

     with ctx:
-        model = PaLM(num_tokens=256, dim=512, depth=8)
+        model = PaLM(num_tokens=50304, dim=4096, depth=64)
         model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)

     pg = default_pg
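The commit swaps the toy configuration for a benchmark-scale one (dim=4096, depth=64). The new num_tokens=50304 is most plausibly GPT-2's 50257-token vocabulary padded up to a multiple of 128, which keeps the embedding and output-projection shapes friendly to tensor cores and to even tensor-parallel splits. A hypothetical helper (pad_vocab does not appear in the source) showing the arithmetic:

    def pad_vocab(vocab_size: int, multiple: int = 128) -> int:
        # round up to the nearest multiple, e.g. 50257 -> 50304
        return ((vocab_size + multiple - 1) // multiple) * multiple

    print(pad_vocab(50257))    # 50304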
@@ -205,25 +216,42 @@ else:
     model.cuda()
     optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

+# model is shared after TP
+numel = get_model_size(model)
+get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)
+
 # training
 model.train()
+tflops_list = []
 for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):

     if args.distplan == "colossalai":
         optimizer.zero_grad()
+        start = time()
         loss = model(next(train_loader))
+        fwd_end = time()
+        fwd_time = fwd_end - start
         # loss.backward()
         optimizer.backward(loss)
+        bwd_end = time()
+        bwd_time = bwd_end - fwd_end

-        print(f"training loss: {loss.item()}")
+        # print(f"training loss: {loss.item()}")
         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
         # optim.step()
         # optim.zero_grad()
         optimizer.step()
+        optim_time = time() - bwd_end
+        step_time = time() - start
+
+        step_tflops = get_tflops_func(step_time)
+        logger.info(
+            f"[{i + 1}/{NUM_BATCHES}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
+            ranks=[0],
+        )
+        if i >= WARMUP_BATCHES:
+            tflops_list.append(step_tflops)
+
     else:
         for __ in range(GRADIENT_ACCUMULATE_EVERY):
             loss = model(next(train_loader))
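The per-phase numbers above come from time() wrapped around asynchronous CUDA launches, so the forward reading can come back artificially small and the missing time then surfaces in whichever later call happens to synchronize. Averaged over many iterations the whole-step time is still representative, since each step's loss.item() in the log line blocks until that step's kernels have run. A minimal sketch (not part of the commit) of a synchronized per-phase timer, at the cost of a little throughput:

    import time

    import torch

    def timed(fn, *args, **kwargs):
        torch.cuda.synchronize()    # drain kernels queued by earlier phases
        start = time.time()
        out = fn(*args, **kwargs)
        torch.cuda.synchronize()    # wait for this phase's kernels to finish
        return out, time.time() - start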
@@ -234,6 +262,11 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
         optim.step()
         optim.zero_grad()

+tflops_list.sort()
+median_index = ((NUM_BATCHES - WARMUP_BATCHES) >> 1) + WARMUP_BATCHES
+logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
+
+
 # TODO
 # if i % VALIDATE_EVERY == 0:
 #     model.eval()
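tflops_list holds only the NUM_BATCHES - WARMUP_BATCHES post-warmup samples, so adding WARMUP_BATCHES back onto the midpoint of the sorted list lands one slot past the true median (index 50 rather than 49 for the 99 samples the new defaults produce). The pick stays in bounds while WARMUP_BATCHES is much smaller than NUM_BATCHES, but a direct computation avoids the offset entirely, sketched here against the script's own variables:

    import statistics

    # assumes at least one post-warmup sample was recorded
    logger.info(f"Median TFLOPS is {statistics.median(tflops_list):.3f}")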