# mirror of https://github.com/hpcaitech/ColossalAI
import math
import os
import time
from functools import partial

import torch
from arguments import parse_args
from evaluation import evaluate
from loss import LossForPretraining
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
from pretrain_utils import get_lr_scheduler, get_model, get_optimizer, save_ckpt
from tqdm import tqdm
from transformers import AutoTokenizer
from utils.exp_util import get_mem_info, get_tflops, log_args, throughput_calculator
from utils.global_vars import get_tensorboard_writer, get_timers, set_global_variables
from utils.logger import Logger

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.context import ParallelMode
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils.model.colo_init_context import ColoInitContext


def main():
    args = parse_args()
    launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug)

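    # vscode_debug starts a single, manually configured process for local debugging;
    # otherwise ranks are taken from the torchrun launcher environment (LOCAL_RANK, etc.).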
    if args.vscode_debug:
        colossalai.launch(
            rank=args.rank, world_size=args.world_size, host=args.host, port=args.port, backend=args.backend
        )
        args.local_rank = -1
        args.log_interval = 1
    else:
        colossalai.launch_from_torch()  # args.colossal_config
        args.local_rank = int(os.environ["LOCAL_RANK"])
        logger.info(
            f"launch_from_torch, world size: {torch.distributed.get_world_size()} | "
            + f"ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}"
        )

    log_args(logger, args)
    args.tokenizer = tokenizer
    args.logger = logger
    set_global_variables(launch_time, args.tensorboard_path)

    world_size = torch.distributed.get_world_size()
    get_accelerator().get_current_device()

    # build model, optimizer and criterion
    if args.distplan.startswith("CAI"):
        # all parameters must use the same process group
        world_size = torch.distributed.get_world_size()
        shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
        default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

        if args.shardinit and args.distplan != "CAI_Gemini":
            raise RuntimeError("You can only use shardinit with CAI_Gemini")

        # build the BERT model
        with ColoInitContext(
            device=get_accelerator().get_current_device(),
            dtype=torch.half,
            default_dist_spec=default_dist_spec,
            default_pg=shard_pg,
        ):
            config, model, numel = get_model(args, logger)

        # assign running configurations
        gemini_config = None
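        # Per-plan configs: the CAI_ZeRO* plans only tune the ZeRO optimizer's communication
        # (reduce bucket size, overlap), while CAI_Gemini also configures chunk-based parameter
        # management (placement_policy, pinned CPU memory, chunk-size search via hidden_dim / search_range_m).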
        if args.distplan.startswith("CAI_ZeRO"):
            optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
        elif args.distplan == "CAI_Gemini":
            gemini_config = dict(
                strict_ddp_mode=args.tp_degree == 1,
                device=get_accelerator().get_current_device(),
                placement_policy=args.placement,
                pin_memory=True,
                hidden_dim=model.config.hidden_size,
                search_range_m=128,
            )
            optim_config = dict(gpu_margin_mem_ratio=0.0)
        else:
            raise RuntimeError(f"Unsupported distplan: {args.distplan}")

        # build a highly optimized gpu/cpu optimizer
        optimizer = get_optimizer(model, lr=args.lr)

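        # Map the distplan to a ZeRO stage: stage 1 shards optimizer states, stage 2 additionally
        # shards gradients, and stage 3 (Gemini here) shards the parameters as well.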
        if args.distplan == "CAI_ZeRO1":
            zero_stage = 1
        elif args.distplan == "CAI_ZeRO2":
            zero_stage = 2
        elif args.distplan == "CAI_Gemini":
            zero_stage = 3
        else:
            raise RuntimeError(f"Unsupported distplan: {args.distplan}")

        # wrap your model and optimizer
        model = zero_model_wrapper(model, zero_stage, gemini_config)
        optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)

        logger.info(get_mem_info(prefix="After init optim, "))

    else:
        config, model, numel = get_model(args, logger)
        logger.info("no_zero")

    if torch.distributed.get_rank() == 0:
        os.mkdir(os.path.join(args.ckpt_path, launch_time))

    logger.info(f"Model numel: {numel}")

    get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length)

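    # Steps per epoch: total sample count divided by the effective global batch
    # (world_size * micro batch size * gradient accumulation) and by refresh_bucket_size.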
    # 144003367 is the length of the entire dataset
    # len(dataloader)
    steps_per_epoch = (
        144003367
        // world_size
        // args.train_micro_batch_size_per_gpu
        // args.gradient_accumulation_steps
        // args.refresh_bucket_size
    )
    total_steps = steps_per_epoch * args.epoch

    lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1)

    start_epoch = 0
    start_shard = 0
    global_step = 0
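    # resume_train restores optimizer and LR-scheduler state and continues from the saved epoch/shard.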
    if args.resume_train:
        assert os.path.exists(args.load_optimizer_lr)
        o_l_state_dict = torch.load(args.load_optimizer_lr, map_location="cpu")
        o_l_state_dict["lr_scheduler"]["last_epoch"] = o_l_state_dict["lr_scheduler"]["last_epoch"] - 1
        optimizer.load_state_dict(o_l_state_dict["optimizer"])
        # o_l_state_dict['lr_scheduler']['last_epoch']
        lr_scheduler = get_lr_scheduler(
            optimizer, total_steps=total_steps, last_epoch=o_l_state_dict["lr_scheduler"]["last_epoch"]
        )
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}")
        # If you delete the loop above, you must move the model to the GPU instead, because
        # optimizer.step() needs the optimizer state on the same device as the parameters.
        lr_scheduler.load_state_dict(o_l_state_dict["lr_scheduler"])

        start_epoch = o_l_state_dict["epoch"]
        start_shard = o_l_state_dict["shard"] + 1
        # global_step = o_l_state_dict['global_step'] + 1
        logger.info(
            f"resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}"
        )

    criterion = LossForPretraining(config.vocab_size)

    # build dataloader
    pretrain_dataset_provider = NvidiaBertDatasetProvider(args)

    logger.info(get_mem_info(prefix="After init model, "))

    eval_loss = 0
    train_loss = 0
    timers = get_timers()
    timers("interval_time").start()
    timers("epoch_time").start()
    timers("shard_time").start()

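    # The inner loop streams one data shard (one file under args.data_path_prefix) at a time,
    # so the whole dataset never has to be resident in memory at once.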
    for epoch in range(start_epoch, args.epoch):
        for shard in range(start_shard, len(os.listdir(args.data_path_prefix))):
            dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard)
            # pretrain_dataset_provider.prefetch_shard(shard + 1)  # may cause cpu memory overload
            if torch.distributed.get_rank() == 0:
                iterator_data = tqdm(
                    enumerate(dataset_iterator),
                    total=(total_length // args.train_micro_batch_size_per_gpu // world_size),
                    colour="cyan",
                    smoothing=1,
                )
            else:
                iterator_data = enumerate(dataset_iterator)

            model.train()

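            # Masked-LM training step: move the batch to the local GPU, run the forward pass,
            # and compute the loss on the MLM labels only (the NSP label is left unused).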
            for step, batch_data in iterator_data:
                # batch_data = pretrain_dataset_provider.get_batch(batch_index)
                input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}")
                attention_mask = batch_data[1].cuda(f"cuda:{torch.cuda.current_device()}")
                token_type_ids = batch_data[2].cuda(f"cuda:{torch.cuda.current_device()}")
                mlm_label = batch_data[3].cuda(f"cuda:{torch.cuda.current_device()}")
                # nsp_label = batch_data[5].cuda()

                output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

                loss = criterion(output.logits, mlm_label)
                pretrain_dataset_provider.prefetch_batch()

                optimizer.backward(loss)
                train_loss += loss.float().item()
                # if (step + 1) % args.accumulation_step == 0:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                global_step += 1

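                # Rank-0 logging every log_interval steps: interval-averaged loss, perplexity,
                # and two TFLOPS estimates (analytic get_tflops vs. throughput_calculator).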
                if global_step % args.log_interval == 0 and global_step != 0 and torch.distributed.get_rank() == 0:
                    elapsed_time = timers("interval_time").elapsed(reset=False)
                    elapsed_time_per_iteration = elapsed_time / global_step
                    samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(
                        numel, args, config, elapsed_time, global_step, world_size
                    )

                    cur_loss = train_loss / args.log_interval
                    current_lr = lr_scheduler.get_last_lr()[0]
                    log_str = (
                        f"| epoch: {epoch} | shard: {shard} | step: {global_step} | lr {current_lr:.7f} | elapsed_time: {elapsed_time / 60:.3f} minutes "
                        + f"| secs/batch: {elapsed_time_per_iteration:.3f} seconds | loss: {cur_loss:.7f} | ppl: {math.exp(cur_loss):.3f} | TFLOPS: {get_tflops_func(elapsed_time_per_iteration):.3f} or {tflops:.3f}"
                    )
                    logger.info(log_str, print_=False)

                    if args.wandb:
                        tensorboard_log = get_tensorboard_writer()
                        tensorboard_log.log_train(
                            {
                                "lr": current_lr,
                                "loss": cur_loss,
                                "ppl": math.exp(cur_loss),
                                "mins_batch": elapsed_time_per_iteration,
                            },
                            global_step,
                        )

                    train_loss = 0

            logger.info(f'epoch {epoch} shard {shard} has cost {timers("shard_time").elapsed() / 60:.3f} mins')
            logger.info("*" * 100)

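            # End of shard: evaluate, accumulate the eval loss, and checkpoint model/optimizer/scheduler
            # together with the epoch/shard/step counters needed to resume training.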
            eval_loss += evaluate(model, args, logger, global_step, criterion)
            save_ckpt(
                model,
                optimizer,
                lr_scheduler,
                os.path.join(args.ckpt_path, launch_time, f"epoch-{epoch}_shard-{shard}_" + launch_time),
                epoch,
                shard,
                global_step,
            )

        eval_loss /= len(os.listdir(args.data_path_prefix))
        logger.info(
            f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60:.3f} mins'
            + f" | eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}"
        )
        logger.info("-" * 100)
        if args.wandb and torch.distributed.get_rank() == 0:
            tensorboard_log = get_tensorboard_writer()
            tensorboard_log.log_eval(
                {
                    "all_eval_shard_loss": eval_loss,
                },
                epoch,
            )
        start_shard = 0
        eval_loss = 0

    pretrain_dataset_provider.release_shard()

    logger.info("Congratulations, training has finished!")


if __name__ == "__main__":
    main()