mirror of https://github.com/hpcaitech/ColossalAI
polish code
parent 9cba38b492
commit e64a05b38b
@@ -1,22 +1,22 @@
 import gzip
 import random
-from time import time
 from functools import partial
+from time import time

 import numpy as np
 import torch
-import torch.optim as optim
 import torch.nn as nn
+import torch.optim as optim
 import tqdm
 from packaging import version
 from palm_pytorch import PaLM
 from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
-from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset

 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel import GeminiDDP, ZeroDDP
+from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import MultiTimer, get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
@@ -69,6 +69,7 @@ def parse_args():
     args = parser.parse_args()
     return args

+
 # helpers
 def cycle(loader):
     while True:
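The hunk cuts off inside cycle(). In the PaLM/enwik8 examples this helper is an infinite generator over a DataLoader, so the elided body is presumably close to this sketch (not the file's exact lines):

def cycle(loader):
    # restart the DataLoader whenever it is exhausted, so the training loop
    # can call next(train_loader) forever without hitting StopIteration
    while True:
        for data in loader:
            yield data

Wrapping the train and validation DataLoaders this way is what lets the training loop below simply call model(next(train_loader)) every step.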
@@ -79,12 +80,15 @@ def cycle(loader):
 def decode_token(token):
     return str(chr(max(32, token)))

+
 def get_tflops(model_numel, batch_size, seq_len, step_time):
     return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)

+
 def decode_tokens(tokens):
     return "".join(list(map(decode_token, tokens)))

+
 def get_model_size(model: nn.Module):
     total_numel = 0
     for module in model.modules():
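get_tflops estimates achieved throughput from model size alone: the factor of 8 is the common forward (2) + backward (4) + activation-recomputation (2) FLOPs-per-parameter-per-token approximation, and the 1e-12 term only guards against division by zero. A quick numeric check with illustrative values (not the script's defaults):

# 0.5B parameters, batch size 8, sequence length 1024, 0.5 s per optimizer step
print(get_tflops(5e8, 8, 1024, 0.5))   # -> 65.536 TFLOPS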
@@ -92,6 +96,7 @@ def get_model_size(model: nn.Module):
             total_numel += p.numel()
     return total_numel

+
 # Gemini + ZeRO DDP
 def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
     cai_version = colossalai.__version__
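gemini_zero_dpp branches on cai_version before falling through to the NotImplemented line shown at the start of the next hunk. The branches themselves are elided here; a version gate of that shape, built on the packaging.version import at the top of the file, looks roughly like the sketch below (the cutoff version and return values are placeholders; NotImplementedError is the conventional exception rather than the bare NotImplemented constant):

from packaging import version

def dispatch_on_cai_version(cai_version: str) -> str:
    # hypothetical gate: pick a wrapping path by comparing the installed
    # colossalai version against a placeholder cutoff
    if version.parse(cai_version) >= version.parse("0.1.10"):
        return "new-style ZeroDDP wrapping"
    if version.parse(cai_version) > version.parse("0.0.0"):
        return "legacy wrapping"
    raise NotImplementedError(f"CAI version {cai_version} is not supported")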
@@ -115,6 +120,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
         raise NotImplemented(f"CAI version {cai_version} is not supported")
     return model

+
 ## Parameter Sharding Strategies for Tensor Parallelism
 def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
     spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
@@ -128,6 +134,7 @@ def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
 def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
     split_param_single_dim_tp1d(-1, param, pg)

+
 # Tensor Parallel
 def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
     """tensor_parallelize
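split_param_row_tp1d and split_param_col_tp1d differ only in which dimension of a weight is sharded across the tensor-parallel group (dim 0 versus the last dim), each combined with the TP1D compute pattern. Ignoring the ColoTensor machinery, the effect on a plain weight matrix can be pictured with torch.chunk (toy shapes, hypothetical 4-way group):

import torch

weight = torch.randn(8, 8)      # toy linear-layer weight, not from the script
tp_world_size = 4               # hypothetical tensor-parallel group size

row_shards = torch.chunk(weight, tp_world_size, dim=0)    # "row" split: each rank keeps a 2 x 8 slice
col_shards = torch.chunk(weight, tp_world_size, dim=-1)   # "col" split: each rank keeps an 8 x 2 slice
print(row_shards[0].shape, col_shards[0].shape)           # torch.Size([2, 8]) torch.Size([8, 2])

tensor_parallelize, whose body is largely elided here, presumably walks the model and applies one of these two splits per parameter.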
@@ -159,7 +166,7 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):

 args = parse_args()
 if args.distplan not in ["colossalai", "pytorch"]:
     raise TypeError(f"{args.distplan} is error")
 disable_existing_loggers()
 colossalai.launch_from_torch(config={})
 logger = get_dist_logger()
@@ -216,7 +223,7 @@ else:
     model.cuda()
     optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

 # model is shared after TP
 numel = get_model_size(model)
 get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)
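functools.partial pins the first three arguments of get_tflops (parameter count, batch size, sequence length), so inside the training loop only the measured step time still has to be supplied:

# equivalent to get_tflops(numel, args.batch_size, SEQ_LEN, step_time)
step_time = 0.5                           # illustrative value; the script times each step itself
step_tflops = get_tflops_func(step_time)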
@@ -251,7 +258,7 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
         )
         if i >= WARMUP_BATCHES:
             tflops_list.append(step_tflops)

     else:
         for __ in range(GRADIENT_ACCUMULATE_EVERY):
             loss = model(next(train_loader))
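In the plain-PyTorch branch, GRADIENT_ACCUMULATE_EVERY forward passes are taken before the single clip/step/zero_grad that opens the next hunk. The backward calls fall in lines not shown here; a minimal PyTorch-only sketch of one accumulated step, assuming each loss is scaled by the number of micro-batches, is:

optim.zero_grad()
for __ in range(GRADIENT_ACCUMULATE_EVERY):
    loss = model(next(train_loader))
    # scale so the summed gradients match one large-batch step
    (loss / GRADIENT_ACCUMULATE_EVERY).backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optim.step()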
@@ -261,18 +268,17 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
         optim.step()
         optim.zero_grad()

 tflops_list.sort()
 median_index = ((NUM_BATCHES - WARMUP_BATCHES) >> 1) + WARMUP_BATCHES
 logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
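The >> 1 above is integer halving, so median_index lands halfway through the post-warmup range, offset by the warmup count. With illustrative values:

NUM_BATCHES, WARMUP_BATCHES = 100, 10   # illustrative values, not the script's defaults
median_index = ((NUM_BATCHES - WARMUP_BATCHES) >> 1) + WARMUP_BATCHES
print(median_index)                     # 55, the midpoint of batches 10..99

Since the loop above only appends post-warmup measurements to tflops_list, this index stays in range only while WARMUP_BATCHES is small relative to the number of recorded steps.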
-
 # TODO
 # if i % VALIDATE_EVERY == 0:
 #     model.eval()
 #     with torch.no_grad():
 #         loss = model(next(val_loader))
 #         print(f"validation loss: {loss.item()}")
+

 # if i % GENERATE_EVERY == 0:
 #     model.eval()
@@ -282,4 +288,4 @@ logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")

 #     sample = model.generate(inp[None, ...], GENERATE_LENGTH)
 #     output_str = decode_tokens(sample[0])
 #     print(output_str)
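If the generation block is re-enabled, decode_tokens turns the sampled token IDs straight back into text: decode_token maps each ID to a character and clamps everything below 32 to a space. Illustrative IDs, not real model output:

sample = [80, 97, 76, 77, 10, 33]
print(decode_tokens(sample))            # "PaLM !" - the newline byte (10) is clamped to a space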