[example] make palm + GeminiDDP work (#2227)

Jiarui Fang 2022-12-29 14:28:31 +08:00 committed by GitHub
parent 63cc77173b
commit 2cdecc9f38
3 changed files with 39 additions and 56 deletions

View File

@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from torch import einsum, nn, matmul
+from torch import einsum, matmul, nn
 # normalization
 # they use layernorm without bias, something that pytorch does not offer
@@ -86,8 +86,6 @@ def FeedForward(dim, mult=4):
 # attention
 class Attention(nn.Module):
     def __init__(self, dim, dim_head=64, heads=8):
@@ -142,8 +140,6 @@ class Attention(nn.Module):
         q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
         # split heads
         # they use multi-query single-key-value attention, yet another Noam Shazeer paper
         # they found no performance loss past a certain scale, and more efficient decoding obviously
@@ -165,7 +161,7 @@ class Attention(nn.Module):
         # similarity
         #sim = einsum("b h i d, b j d -> b h i j", q, k)
-        sim = matmul(q.reshape(b, h*i, d), k.transpose(1,2))
+        sim = matmul(q.reshape(b, h * i, d), k.transpose(1, 2))
         sim = sim.reshape(b, h, i, j)
         # causal mask
@@ -183,7 +179,7 @@ class Attention(nn.Module):
         # aggregate values
         #out = einsum("b h i j, b j d -> b h i d", attn, v)
-        out = matmul(attn.reshape(b_, h_*i_, j_), v)
+        out = matmul(attn.reshape(b_, h_ * i_, j_), v)
         out = out.reshape(b_, h_, i_, d_)
         # merge heads
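
The two hunks above replace the einsum calls with an equivalent reshape + matmul while leaving the multi-query attention math unchanged. A standalone sanity check of that equivalence (a sketch, not part of the commit; shapes assume h query heads sharing a single key/value head, as the comments in the model describe):

    import torch
    from torch import einsum, matmul

    # multi-query attention shapes: h query heads share one key/value head
    b, h, i, j, d = 2, 8, 16, 16, 64
    q = torch.randn(b, h, i, d)
    k = torch.randn(b, j, d)
    v = torch.randn(b, j, d)

    # similarity: einsum vs. the reshape + matmul form used in the diff
    sim_einsum = einsum("b h i d, b j d -> b h i j", q, k)
    sim_matmul = matmul(q.reshape(b, h * i, d), k.transpose(1, 2)).reshape(b, h, i, j)
    assert torch.allclose(sim_einsum, sim_matmul, atol=1e-5)

    # aggregation: einsum vs. reshape + matmul
    attn = sim_matmul.softmax(dim=-1)
    out_einsum = einsum("b h i j, b j d -> b h i d", attn, v)
    out_matmul = matmul(attn.reshape(b, h * i, j), v).reshape(b, h, i, d)
    assert torch.allclose(out_einsum, out_matmul, atol=1e-5)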

View File

@@ -1 +1 @@
-env OMP_NUM_THREADS=12 torchrun --nproc_per_node 8 --master_port 29501 train.py --config palm_config.py
+env OMP_NUM_THREADS=12 torchrun --nproc_per_node 4 --master_port 29501 train.py --config palm_config.py

View File

@@ -5,38 +5,36 @@ import numpy as np
 import torch
 import torch.optim as optim
 import tqdm
+from packaging import version
 from palm_pytorch import PaLM
 from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset
-from packaging import version
 import colossalai
-from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
+from colossalai.nn.parallel import GeminiDDP, ZeroDDP
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import MultiTimer, get_current_device
-from colossalai.nn.parallel import ZeroDDP
-from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel import GeminiDDP
-from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.utils.model.colo_init_context import ColoInitContext
 # constants
-NUM_BATCHES = int(1e5)
+NUM_BATCHES = int(20)
 BATCH_SIZE = 4
-GRADIENT_ACCUMULATE_EVERY = 4
+GRADIENT_ACCUMULATE_EVERY = 1
 LEARNING_RATE = 2e-4
 VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
 GENERATE_LENGTH = 512
 SEQ_LEN = 1024
-TPDEGREE = 2
+TPDEGREE = 1
 USE_SHARD_INIT = False
 placement = 'cpu'
 # helpers
 def cycle(loader):
     while True:
         for data in loader:
@@ -50,6 +48,7 @@ def decode_token(token):
 def decode_tokens(tokens):
     return "".join(list(map(decode_token, tokens)))
 # Gemini + ZeRO DDP
 def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
     cai_version = colossalai.__version__
@@ -73,6 +72,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
         raise NotImplemented(f"CAI version {cai_version} is not supported")
     return model
 # instantiate GPT-like decoder model
 parser = colossalai.get_default_parser()
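
The gemini_zero_dpp helper above is only partially visible in these hunks: it inspects colossalai.__version__ and wraps the model accordingly, raising for unsupported releases. A minimal sketch of what such a wrapper looks like, pieced together from the imports and the commented-out GeminiDDP call further down in this diff; the version cutoff and the pin_memory / search_range_mb arguments are assumptions and may differ between ColossalAI releases:

    import colossalai
    import torch
    from colossalai.nn.parallel import GeminiDDP
    from colossalai.tensor import ProcessGroup
    from colossalai.utils import get_current_device
    from packaging import version

    def gemini_zero_dpp_sketch(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
        # wrap the model so Gemini manages parameter/gradient placement (ZeRO-style);
        # pg is kept only to mirror the signature used in train.py
        cai_version = colossalai.__version__
        if version.parse(cai_version) >= version.parse("0.1.10"):  # assumed cutoff
            model = GeminiDDP(model,
                              device=get_current_device(),
                              placement_policy=placement_policy,
                              pin_memory=True,
                              search_range_mb=32)
        else:
            raise NotImplementedError(f"CAI version {cai_version} is not supported")
        return model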
@@ -80,24 +80,15 @@ args = parser.parse_args()
 disable_existing_loggers()
 colossalai.launch_from_torch(config=args.config, seed=42)
 # instantiate GPT-like decoder model
 default_pg = ProcessGroup(tp_degree=TPDEGREE)
 default_dist_spec = ShardSpec([-1], [TPDEGREE]) if USE_SHARD_INIT else None
 ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
 with ctx:
-    model = PaLM(num_tokens=256,dim=512,depth=8)
+    model = PaLM(num_tokens=256, dim=512, depth=8)
     model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
-    model.cuda()
-# prepare enwik8 data
-# model = PaLM(num_tokens=256, dim=512, depth=8)
-# model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
-# model.cuda()
 with gzip.open("./data/enwik8.gz") as file:
     X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
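
For context on the ColoInitContext block above: the ProcessGroup and the optional ShardSpec decide how parameters created inside the context are laid out across ranks. A hedged illustration of the sharded variant (the Linear layer and the 2-way group are hypothetical; with this example's defaults of TPDEGREE = 1 and USE_SHARD_INIT = False, parameters stay replicated):

    import torch
    from colossalai.tensor import ProcessGroup, ShardSpec
    from colossalai.utils.model.colo_init_context import ColoInitContext

    # hypothetical 2-way tensor-parallel group (requires a launch with world_size >= 2)
    pg = ProcessGroup(tp_degree=2)
    # shard the last dimension of each parameter into 2 pieces across that group
    dist_spec = ShardSpec([-1], [2])

    with ColoInitContext(device='cpu', default_dist_spec=dist_spec, default_pg=pg):
        # parameters created here become ColoParameters carrying the spec above;
        # with default_dist_spec=None they would be fully replicated instead
        layer = torch.nn.Linear(512, 512)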
@@ -129,46 +120,42 @@ val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
 #tensor_parallelize(model, pg)
 pg = default_pg
-# model = GeminiDDP(model,
-#                   device=get_current_device(),
-#                   placement_policy="auto",
-#                   pin_memory=True,
-#                   search_range_mb=32)
 model = gemini_zero_dpp(model, pg, placement)
 #optimizer
 optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5)
-#optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
 # training
-model.train()
 for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
+    model.train()
-    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+    optimizer.zero_grad()
-        loss = model(next(train_loader))
-        loss.backward()
+    loss = model(next(train_loader))
+    # loss.backward()
+    optimizer.backward(loss)
     print(f"training loss: {loss.item()}")
     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
     # optim.step()
     # optim.zero_grad()
     optimizer.step()
-    optimizer.zero_grad()
-    if i % VALIDATE_EVERY == 0:
+    # TODO
-        model.eval()
+    # if i % VALIDATE_EVERY == 0:
-        with torch.no_grad():
+    #     model.eval()
-            loss = model(next(val_loader))
+    #     with torch.no_grad():
-            print(f"validation loss: {loss.item()}")
+    #         loss = model(next(val_loader))
+    #         print(f"validation loss: {loss.item()}")
-    if i % GENERATE_EVERY == 0:
+    # if i % GENERATE_EVERY == 0:
-        model.eval()
+    #     model.eval()
-        inp = random.choice(val_dataset)[:-1]
+    #     inp = random.choice(val_dataset)[:-1]
-        prime = decode_tokens(inp)
+    #     prime = decode_tokens(inp)
-        print(f"%s \n\n %s", (prime, "*" * 100))
+    #     print(f"%s \n\n %s", (prime, "*" * 100))
-        sample = model.generate(inp[None, ...], GENERATE_LENGTH)
+    #     sample = model.generate(inp[None, ...], GENERATE_LENGTH)
-        output_str = decode_tokens(sample[0])
+    #     output_str = decode_tokens(sample[0])
-        print(output_str)
+    #     print(output_str)
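
After this change, each training iteration follows the pattern below (a condensed restatement of the loop above, reusing model, optimizer and train_loader from train.py; the gradient-accumulation loop is gone, and optimizer.backward(loss) replaces loss.backward() so the ZeRO/Gemini optimizer can apply its own loss scaling, as the initial_scale argument suggests, and manage the chunked gradients itself):

    # one training iteration with the Gemini-wrapped model and GeminiAdamOptimizer
    model.train()
    optimizer.zero_grad()
    loss = model(next(train_loader))                          # AutoregressiveWrapper returns the LM loss
    optimizer.backward(loss)                                  # not loss.backward(): the ZeRO optimizer drives backward
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)   # clip before the parameter update
    optimizer.step()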