mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
125 lines
4.3 KiB
125 lines
4.3 KiB
import logging
|
|
import os
|
|
import sys
|
|
|
|
import torch
|
|
import transformers
|
|
from torch.optim import AdamW
|
|
from transformers import (
|
|
AutoModelForMaskedLM,
|
|
AutoTokenizer,
|
|
BertForPreTraining,
|
|
GPT2Config,
|
|
GPT2LMHeadModel,
|
|
RobertaConfig,
|
|
RobertaForMaskedLM,
|
|
get_linear_schedule_with_warmup,
|
|
)
|
|
|
|
from colossalai.core import global_context as gpc
|
|
from colossalai.nn.lr_scheduler import LinearWarmupLR
|
|
from colossalai.nn.optimizer import FusedAdam, HybridAdam
|
|
|
|
sys.path.append(os.getcwd())
|
|
from collections import OrderedDict
|
|
|
|
import torch.nn as nn
|
|
from model.bert import BertForMaskedLM
|
|
from model.deberta_v2 import DebertaV2ForMaskedLM
|
|
|
|
# Public helpers exported by ``from <this module> import *``. Every name listed here
# must be defined in this module, otherwise the star-import raises AttributeError.
__all__ = ['get_model', 'get_optimizer', 'get_lr_scheduler', 'save_ckpt']
|
|
|
|
|
|
def get_new_state_dict(state_dict, start_index=13):
    """Return a copy of ``state_dict`` with the first ``start_index`` characters of every key removed.

    The result is an ``OrderedDict`` that keeps the iteration order of the input.
    Keys are sliced unconditionally: no check is made that the removed characters
    form any particular prefix. Keys that collapse to the same name overwrite each
    other, and the last value wins.

    NOTE(review): the default of 13 equals len('module.model.'). Presumably that is
    the wrapper prefix this is meant to strip; confirm against the saving code.
    """
    return OrderedDict((key[start_index:], value) for key, value in state_dict.items())
|
|
|
|
|
|
class LMModel(nn.Module):
    """``nn.Module`` wrapper that holds a language model plus its config.

    The wrapper records ``args.checkpoint_activations`` as ``self.checkpoint``. When
    that flag is truthy, it calls ``gradient_checkpointing_enable()`` on the wrapped
    model. NOTE(review): this presumes a HuggingFace-style model exposing that method.
    """

    def __init__(self, model, config, args):
        super().__init__()
        use_checkpointing = args.checkpoint_activations
        self.checkpoint = use_checkpointing
        self.config = config
        self.model = model
        if not use_checkpointing:
            return
        self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Call the wrapped model with the given inputs and return its output unchanged.

        The output is returned as-is rather than reduced to logits; whatever the wrapped
        model returns is what the caller receives.
        """
        inputs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        return self.model(**inputs)
|
|
|
|
|
|
def get_model(args, logger):
    """Build the masked-LM model selected by ``args.mlm`` and optionally load weights into it.

    Args:
        args: namespace read for the following attributes:
            * ``mlm``: either ``'bert'`` or ``'deberta_v2'``.
            * ``bert_config``: path to a JSON model-config file.
            * ``load_pretrain_model``: checkpoint path; empty or ``None`` skips loading.
            * ``checkpoint_activations``: when truthy, enables gradient checkpointing.
        logger: logger used to report a successful checkpoint load.

    Returns:
        ``(config, model, numel)``, where ``numel`` is the total element count across
        ``model.parameters()``.

    Raises:
        ValueError: if ``args.mlm`` is not a supported model type.
        FileNotFoundError: if ``args.load_pretrain_model`` is set but the path does not exist.
    """
    if args.mlm == 'bert':
        config = transformers.BertConfig.from_json_file(args.bert_config)
        model = BertForMaskedLM(config)
    elif args.mlm == 'deberta_v2':
        config = transformers.DebertaV2Config.from_json_file(args.bert_config)
        model = DebertaV2ForMaskedLM(config)
    else:
        raise ValueError(f"Invalid mlm! Expected 'bert' or 'deberta_v2', got {args.mlm!r}")

    if args.load_pretrain_model:
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if not os.path.exists(args.load_pretrain_model):
            raise FileNotFoundError(f"Pretrained checkpoint not found: {args.load_pretrain_model}")
        # Map the checkpoint tensors onto this process's current CUDA device.
        m_state_dict = torch.load(args.load_pretrain_model,
                                  map_location=torch.device(f"cuda:{torch.cuda.current_device()}"))
        # strict=True so any key mismatch fails loudly: every process must end up
        # with identical parameters.
        model.load_state_dict(m_state_dict, strict=True)
        logger.info("load model success")

    numel = sum(p.numel() for p in model.parameters())
    if args.checkpoint_activations:
        model.gradient_checkpointing_enable()

    return config, model, numel
|
|
|
|
|
|
def get_optimizer(model, lr, weight_decay=0.1, betas=(0.9, 0.95),
                  no_decay=('bias', 'gamma', 'beta', 'LayerNorm')):
    """Create a ``HybridAdam`` optimizer with two weight-decay parameter groups.

    Any parameter whose name contains one of the ``no_decay`` substrings goes into a
    group with ``weight_decay=0.0``. All other parameters use ``weight_decay``.
    Substring matching is plain ``in`` containment, so e.g. ``'beta'`` matches any
    name containing that text.

    Args:
        model: module whose ``named_parameters()`` are optimized; all parameters are
            included, whether or not they require grad.
        lr: learning rate.
        weight_decay: decay applied to the decayed group (default 0.1).
        betas: Adam beta coefficients (default ``(0.9, 0.95)``).
        no_decay: name substrings that exempt a parameter from weight decay.

    Returns:
        The configured ``HybridAdam`` instance.
    """
    decay_params, no_decay_params = [], []
    # Single pass; the relative parameter order within each group matches named_parameters().
    for name, param in model.named_parameters():
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer_grouped_parameters = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]
    return HybridAdam(optimizer_grouped_parameters, lr=lr, betas=list(betas))
|
|
|
|
|
|
def get_lr_scheduler(optimizer, total_steps, warmup_steps=2000, last_epoch=-1):
    """Return ``transformers.get_linear_schedule_with_warmup`` bound to ``optimizer``.

    The schedule runs over ``total_steps`` training steps with ``warmup_steps`` warmup
    steps. ``last_epoch`` is forwarded unchanged, which allows a resumed run to continue
    the schedule. Per the transformers docs, the LR ramps linearly from 0 during warmup
    and then decays linearly toward 0 at ``total_steps``.
    """
    schedule_kwargs = {
        'num_warmup_steps': warmup_steps,
        'num_training_steps': total_steps,
        'last_epoch': last_epoch,
    }
    return get_linear_schedule_with_warmup(optimizer, **schedule_kwargs)
|
|
|
|
|
|
def save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step):
    """Save model weights and optimizer / LR-scheduler training state.

    Only global rank 0 writes, producing two files:
      * ``path + '_pytorch_model.bin'``: the result of ``model.state_dict()``.
      * ``path + '.op_lrs'``: a dict with keys ``'optimizer'``, ``'lr_scheduler'``,
        ``'epoch'``, ``'shard'`` and ``'global_step'``.

    Every process still calls the ``state_dict()`` methods, in the order optimizer,
    LR scheduler, model.
    """
    weights_file = path + '_pytorch_model.bin'
    train_state_file = path + '.op_lrs'

    train_state = {
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
        'shard': shard,
        'global_step': global_step,
    }
    # Each process must run model.state_dict(), even though only rank 0 writes.
    # NOTE(review): presumably because gathering a sharded model is a collective op; confirm.
    weights = model.state_dict()

    if gpc.get_global_rank() != 0:
        return
    torch.save(train_state, train_state_file)
    torch.save(weights, weights_file)
|