polish code

pull/2484/head
jiaruifang 2023-01-16 14:45:06 +08:00
parent 9cba38b492
commit e64a05b38b
1 changed file with 22 additions and 16 deletions


@@ -1,22 +1,22 @@
import gzip
import random
-from time import time
from functools import partial
+from time import time
import numpy as np
import torch
-import torch.optim as optim
import torch.nn as nn
+import torch.optim as optim
import tqdm
from packaging import version
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
-from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel import GeminiDDP, ZeroDDP
+from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
@@ -69,6 +69,7 @@ def parse_args():
    args = parser.parse_args()
    return args

+
# helpers
def cycle(loader):
    while True:
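The cycle helper above (its body is cut off by the hunk) presumably re-iterates a DataLoader forever so that the training loop can simply call next() on it, as the later next(train_loader) calls suggest. A minimal self-contained sketch with toy data; the dataset shape and batch size are made up:

import torch
from torch.utils.data import DataLoader, TensorDataset

def cycle(loader):
    # presumed full body of the helper: yield batches from the loader forever
    while True:
        for data in loader:
            yield data

dataset = TensorDataset(torch.randint(0, 256, (64, 128)))  # toy stand-in data
train_loader = cycle(DataLoader(dataset, batch_size=4, shuffle=True))
batch = next(train_loader)  # how the training loop consumes it later in the script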
@@ -79,12 +80,15 @@ def cycle(loader):
def decode_token(token):
    return str(chr(max(32, token)))

+
def get_tflops(model_numel, batch_size, seq_len, step_time):
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)

+
def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))

+
def get_model_size(model: nn.Module):
    total_numel = 0
    for module in model.modules():
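A note on get_tflops above: the factor of 8 presumably approximates 8 floating-point operations per parameter per token (forward pass, backward pass, and activation recomputation combined), and the 1e-12 term guards against division by zero. A quick sanity check with hypothetical numbers, none of which come from the script:

def get_tflops(model_numel, batch_size, seq_len, step_time):
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)

# Hypothetical numbers: a 500M-parameter model, batch size 8, sequence length 1024,
# and 0.5 s per step give roughly 65.5 achieved TFLOPS.
print(get_tflops(5e8, 8, 1024, 0.5))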
@@ -92,6 +96,7 @@ def get_model_size(model: nn.Module):
            total_numel += p.numel()
    return total_numel

+
# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
    cai_version = colossalai.__version__
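gemini_zero_dpp reads colossalai.__version__ and, given the `from packaging import version` import in the first hunk, presumably branches on the installed version to pick a compatible ZeroDDP/Gemini wrapper; the unsupported case ends in the raise shown in the next hunk. packaging.version compares release numbers numerically rather than as strings, for example:

from packaging import version

cai_version = "0.2.0"  # hypothetical value of colossalai.__version__
# Plain string comparison would order "0.10.0" before "0.2.0"; version.parse
# compares the numeric components instead.
assert version.parse(cai_version) > version.parse("0.1.10")
assert version.parse("0.10.0") > version.parse("0.2.0")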
@@ -115,6 +120,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
        raise NotImplemented(f"CAI version {cai_version} is not supported")
    return model

+
## Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
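split_param_single_dim_tp1d pairs a ShardSpec([dim], [pg.tp_world_size()]) with a TP1D ComputeSpec, i.e. the parameter is partitioned along a single dimension across the tensor-parallel ranks. The sketch below illustrates that partitioning with plain torch.chunk; it is conceptual only and does not use the ColossalAI API:

import torch

tp_world_size = 4                 # hypothetical tensor-parallel group size
weight = torch.randn(1024, 4096)  # hypothetical linear-layer weight

# ShardSpec([0], [4]) amounts to row sharding: each rank keeps 256 of the 1024 rows.
row_shards = torch.chunk(weight, tp_world_size, dim=0)
# ShardSpec([-1], [4]) amounts to column sharding: each rank keeps 1024 of the 4096 columns.
col_shards = torch.chunk(weight, tp_world_size, dim=-1)

assert row_shards[0].shape == (256, 4096)
assert col_shards[0].shape == (1024, 1024)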
@@ -128,6 +134,7 @@ def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(-1, param, pg)

+
# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """tensor_parallelize
@@ -159,7 +166,7 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
    raise TypeError(f"{args.distplan} is error")

disable_existing_loggers()
colossalai.launch_from_torch(config={})
logger = get_dist_logger()
@@ -216,7 +223,7 @@ else:
    model.cuda()
    optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# model is shared after TP
numel = get_model_size(model)
get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)
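The partial in the last line binds the model size, batch size, and SEQ_LEN, so during training only the measured step time has to be supplied. A sketch of that call pattern with stand-in values; the names and numbers are illustrative, not from the script:

from functools import partial
from time import time

def get_tflops(model_numel, batch_size, seq_len, step_time):
    # copied from the earlier hunk so this sketch runs standalone
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)

numel = 5e8            # hypothetical; the script takes this from get_model_size(model)
batch_size, seq_len = 8, 1024
get_tflops_func = partial(get_tflops, numel, batch_size, seq_len)

start = time()
# ... one training step would run here ...
step_tflops = get_tflops_func(time() - start)  # only the measured step time is still free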
@@ -251,7 +258,7 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
        )
        if i >= WARMUP_BATCHES:
            tflops_list.append(step_tflops)
    else:
        for __ in range(GRADIENT_ACCUMULATE_EVERY):
            loss = model(next(train_loader))
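In the pytorch branch the loop draws GRADIENT_ACCUMULATE_EVERY micro-batches; the backward call for each micro-batch falls between this hunk and the next, which shows the clip, step, and zero_grad lines. A minimal self-contained sketch of that standard gradient-accumulation pattern; the model, data, and constants here are stand-ins:

import torch
import torch.nn as nn
from torch.optim import Adam

GRADIENT_ACCUMULATE_EVERY = 4          # stand-in constant
model = nn.Linear(16, 1)               # stand-in for the PaLM wrapper
optim = Adam(model.parameters(), lr=2e-4)

def next_batch():
    # stand-in for next(train_loader)
    return torch.randn(8, 16), torch.randn(8, 1)

for __ in range(GRADIENT_ACCUMULATE_EVERY):
    x, y = next_batch()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()                    # gradients accumulate across micro-batches

torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step()
optim.zero_grad()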
@@ -261,18 +268,17 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optim.step()
        optim.zero_grad()

tflops_list.sort()
median_index = ((NUM_BATCHES - WARMUP_BATCHES) >> 1) + WARMUP_BATCHES
logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")

-
# TODO
# if i % VALIDATE_EVERY == 0:
#     model.eval()
#     with torch.no_grad():
#         loss = model(next(val_loader))
#         print(f"validation loss: {loss.item()}")

# if i % GENERATE_EVERY == 0:
#     model.eval()
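The reported figure indexes into the sorted per-step TFLOPS list using the formula above; with hypothetical NUM_BATCHES = 100 and WARMUP_BATCHES = 10 it evaluates as follows:

NUM_BATCHES = 100     # hypothetical values, not the script's defaults
WARMUP_BATCHES = 10
median_index = ((NUM_BATCHES - WARMUP_BATCHES) >> 1) + WARMUP_BATCHES
print(median_index)   # (90 >> 1) + 10 = 45 + 10 = 55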
@@ -282,4 +288,4 @@ logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
# sample = model.generate(inp[None, ...], GENERATE_LENGTH)
# output_str = decode_tokens(sample[0])
# print(output_str)