import os
import time
from functools import partial

import torch
from tqdm import tqdm
from transformers import AutoTokenizer

import colossalai
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer

from arguments import parse_args
from evaluation import evaluate
from loss import LossForPretraining
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
from pretrain_utils import get_lr_scheduler, get_model, get_optimizer, save_ckpt
from utils.exp_util import get_mem_info, get_tflops, log_args, throughput_calculator
from utils.global_vars import get_tensorboard_writer, get_timers, set_global_variables
from utils.logger import Logger
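
# NOTE: `arguments`, `evaluation`, `loss`, `nvidia_bert_dataset_provider`,
# `pretrain_utils` and the `utils.*` modules are local helpers that live next to
# this script in the example directory; they are not part of the colossalai
# package itself.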


def main():
    args = parse_args()
    # ... (launch_time and other setup elided) ...

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug)

    if args.vscode_debug:
        # ... (debug-mode setup elided) ...
        args.local_rank = -1
        args.log_interval = 1
    else:
        colossalai.launch_from_torch(config={})    # previously: args.colossal_config
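        # `launch_from_torch` reads RANK / LOCAL_RANK / WORLD_SIZE and the rendezvous
        # address from the environment, so the script is meant to be started through a
        # distributed launcher. A typical invocation (illustrative only; the script
        # name and flags are assumptions, not taken from this file):
        #   torchrun --nproc_per_node=8 run_pretraining.py --tokenizer_path <path> ...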
        args.local_rank = int(os.environ["LOCAL_RANK"])
        logger.info(
            f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' +
            '...')    # the remaining fields of this log line are elided

    args.tokenizer = tokenizer
    args.logger = logger
    set_global_variables(launch_time, args.tensorboard_path)

    world_size = torch.distributed.get_world_size()
    init_dev = get_current_device()

    # build model, optimizer and criterion
    if args.distplan.startswith("CAI"):
        # all params must use the same process group
        world_size = torch.distributed.get_world_size()
        # shard_pg and default_dist_spec come from elided setup code above this line.
        with ColoInitContext(device=init_dev,    # opening line assumed from context
                             dtype=torch.half,
                             default_dist_spec=default_dist_spec,
                             default_pg=shard_pg):
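            # Modules built under ColoInitContext are materialized directly on
            # `init_dev` as ColoParameters in fp16, with `default_dist_spec` and
            # `default_pg` deciding how each parameter is sharded across ranks,
            # so the full fp32 model never has to fit on a single device.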
            config, model, numel = get_model(args, logger)

    # assign running configurations
    # ... (elided) ...

    logger.info(f'Model numel: {numel}')
    get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length)
    # 144003367 is the length of the entire dataset
    steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size    # len(dataloader)
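
    # Worked example with illustrative numbers (not taken from this file): for
    # world_size=8, train_micro_batch_size_per_gpu=32,
    # gradient_accumulation_steps=1 and refresh_bucket_size=1 this gives
    # 144003367 // 8 // 32 // 1 // 1 = 562513 steps per epoch.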
    total_steps = steps_per_epoch * args.epoch

    optimizer = get_optimizer(model, lr=args.lr)
    lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1)

    if args.resume_train:    # flag name assumed; the original branch header is elided
        # ... (checkpoint loading elided; `o_l_state_dict` is created there) ...
        start_epoch = o_l_state_dict['epoch']
        start_shard = o_l_state_dict['shard'] + 1
        # global_step = o_l_state_dict['global_step'] + 1
        logger.info(f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}')

    criterion = LossForPretraining(config.vocab_size)

    # build dataloader
    pretrain_dataset_provider = NvidiaBertDatasetProvider(args)

    logger.info(get_mem_info(prefix='After init model, '))

    best_loss = None

    # The loop headers below are reconstructed from the variables used inside the
    # loop bodies; the original lines are elided in this excerpt.
    for epoch in range(start_epoch, args.epoch):
        # ... (per-epoch setup elided: timers, eval_loss, ...) ...
        for shard in range(start_shard, len(os.listdir(args.data_path_prefix))):
            # ... (dataset_iterator construction elided) ...
            if torch.distributed.get_rank() == 0:    # condition assumed
                iterator_data = tqdm(enumerate(dataset_iterator))
            else:
                iterator_data = enumerate(dataset_iterator)

            model.train()

            for step, batch_data in iterator_data:
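                # Each batch is a tuple of pre-collated tensors from the Nvidia-style
                # dataset shards: index 0 is input_ids and index 3 the masked-LM
                # labels; judging from the calls below, token_type_ids and
                # attention_mask sit in between (the exact ordering of the unused
                # slots is an assumption).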
                # batch_data = pretrain_dataset_provider.get_batch(batch_index)
                input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}")
                # ... (token_type_ids / attention_mask transfers elided) ...
                mlm_label = batch_data[3].cuda(f"cuda:{torch.cuda.current_device()}")
                # nsp_label = batch_data[5].cuda()

                output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
                loss = criterion(output.logits, mlm_label)
                pretrain_dataset_provider.prefetch_batch()

                optimizer.backward(loss)
                train_loss += loss.float().item()

                # if (step + 1) % args.accumulation_step == 0:
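                # `optimizer.backward(loss)` (rather than `loss.backward()`) lets the
                # ZeroOptimizer-style wrapper apply loss scaling and route gradients
                # into its sharded buffers before the step below.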
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                global_step += 1

                if global_step % args.log_interval == 0 and global_step != 0 \
                        and torch.distributed.get_rank() == 0:    # continuation assumed; original lines elided
                    # ... (throughput / tensorboard logging elided) ...

            # ... (remainder of the step loop elided) ...
            logger.info(f'epoch {epoch} shard {shard} has cost {timers("shard_time").elapsed() / 60 :.3f} mins')
            logger.info('*' * 100)

            eval_loss += evaluate(model, args, logger, global_step, criterion)
            save_ckpt(model, optimizer, lr_scheduler,
                      os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time),
                      epoch, shard, global_step)
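
            # A checkpoint is written once per shard; the path embeds epoch, shard
            # and launch_time, so successive runs do not overwrite each other.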

        eval_loss /= len(os.listdir(args.data_path_prefix))
        logger.info(
            f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins'
        )