@@ -1,22 +1,22 @@
import gzip
import random
from functools import partial
from time import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from packaging import version
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import GeminiDDP, ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
@@ -69,6 +69,7 @@ def parse_args():
    args = parser.parse_args()
    return args


# helpers
def cycle(loader):
    while True:
@@ -79,12 +80,15 @@ def cycle(loader):
def decode_token(token):
    # tokens are raw byte values; render anything below 32 (control bytes) as a space
    return str(chr(max(32, token)))


def get_tflops(model_numel, batch_size, seq_len, step_time):
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
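
# Note on the constant 8 in get_tflops: this is presumably the usual
# FLOPs-per-parameter-per-token estimate with activation checkpointing,
# 2 FLOPs per multiply-accumulate * (1 forward + 2 backward + 1 recompute) = 8.
# Worked example with made-up numbers: a 500M-parameter model at batch size 8,
# sequence length 1024, and 0.5 s per step gives
# 500e6 * 8 * 1024 * 8 / 1e12 / 0.5 ≈ 65.5 TFLOPS.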


def decode_tokens(tokens):
    return "".join(list(map(decode_token, tokens)))


def get_model_size(model: nn.Module):
    total_numel = 0
    for module in model.modules():
@@ -92,6 +96,7 @@ def get_model_size(model: nn.Module):
            total_numel += p.numel()
    return total_numel


# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
    cai_version = colossalai.__version__
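    # (elided by this hunk: the branches that dispatch on cai_version and wrap
    # the model, presumably in GeminiDDP on newer releases and ZeroDDP on older
    # ones, matching the imports at the top of the file)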
@@ -115,6 +120,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
        raise NotImplementedError(f"CAI version {cai_version} is not supported")
    return model
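
# Minimal call-site sketch for gemini_zero_dpp (the real call site is elided
# from this diff, so the PaLM hyperparameters and ProcessGroup setup below are
# assumptions):
#
#     pg = ProcessGroup(tp_degree=args.tp_degree)
#     with ColoInitContext(device=get_current_device()):
#         model = PaLM(num_tokens=256, dim=512, depth=8)
#     model = gemini_zero_dpp(model, pg, placememt_policy="auto")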


# Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
@@ -128,6 +134,7 @@ def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    # shard the last dimension (columns) across the tensor-parallel group
    split_param_single_dim_tp1d(-1, param, pg)
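
# For example, with pg.tp_world_size() == 2, a (1024, 1024) weight becomes a
# (512, 1024) shard per rank under split_param_row_tp1d (dim 0) and a
# (1024, 512) shard per rank under split_param_col_tp1d (dim -1).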


# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """tensor_parallelize
@@ -216,7 +223,7 @@ else:
    model.cuda()
    optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# model is sharded after TP
numel = get_model_size(model)
get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)
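# get_tflops_func now only needs the measured step time; a hypothetical use:
#     start = time()
#     # ... one training step ...
#     logger.info(f"{get_tflops_func(time() - start):.3f} TFLOPS")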
@@ -266,13 +273,12 @@ tflops_list.sort()
# assuming the warmup steps sort to the low end, this indexes the median of
# the steady-state measurements
median_index = ((NUM_BATCHES - WARMUP_BATCHES) >> 1) + WARMUP_BATCHES
logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")


# TODO
# if i % VALIDATE_EVERY == 0:
#     model.eval()
#     with torch.no_grad():
#         loss = model(next(val_loader))
#         print(f"validation loss: {loss.item()}")

# if i % GENERATE_EVERY == 0:
#     model.eval()