import os
import sys
import torch
import transformers
from transformers import get_linear_schedule_with_warmup
from colossalai.legacy.core import global_context as gpc
from colossalai.nn.optimizer import HybridAdam
from collections import OrderedDict
import torch.nn as nn
from model.bert import BertForMaskedLM
from model.deberta_v2 import DebertaV2ForMaskedLM
# Public names exported by ``from <this module> import *``.
# NOTE(review): "get_dataloader_for_pretraining" is not defined in the visible
# part of this file. A name listed in __all__ that does not exist makes a star
# import raise AttributeError -- confirm it is defined elsewhere or drop it.
__all__ = ["get_model", "get_optimizer", "get_lr_scheduler", "get_dataloader_for_pretraining"]
def get_new_state_dict(state_dict, start_index=13):
    """Return a copy of ``state_dict`` with a fixed-length prefix cut off every key.

    Args:
        state_dict: Mapping of parameter names to values.
        start_index: Number of leading characters removed from each key.
            NOTE(review): the default of 13 presumably matches a wrapper prefix
            such as ``"module.model."`` -- confirm against the checkpoints used.

    Returns:
        An ``OrderedDict`` in the original iteration order, keyed by the
        truncated names. The values are the same objects, not copies.
    """
    return OrderedDict((key[start_index:], value) for key, value in state_dict.items())
class LMModel(nn.Module):
    """Thin ``nn.Module`` wrapper that forwards its inputs to a wrapped model.

    Args:
        model: The wrapped model. It is called with the keyword arguments
            ``input_ids``, ``token_type_ids`` and ``attention_mask``.
        config: Model configuration object, stored unchanged on ``self.config``.
        args: Namespace exposing ``checkpoint_activations``; when truthy,
            gradient (activation) checkpointing is enabled on the wrapped model.
    """

    def __init__(self, model, config, args):
        # nn.Module must be initialised before a sub-module can be assigned;
        # without this call ``self.model = model`` raises AttributeError.
        super().__init__()
        self.checkpoint = args.checkpoint_activations
        self.config = config
        self.model = model
        if self.checkpoint:
            # NOTE(review): assumes the wrapped model provides the Hugging Face
            # ``gradient_checkpointing_enable()`` method -- confirm for the
            # project-local BERT / DeBERTa-v2 classes.
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Call the wrapped model and return its output unchanged.

        The original note here says only the LM logits are returned; that
        depends on what the wrapped model's own forward produces.
        """
        return self.model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
def get_model(args, logger):
    """Build the masked-LM model selected by ``args.mlm`` and optionally load weights.

    Args:
        args: Namespace providing
            ``mlm`` -- ``"bert"`` or ``"deberta_v2"``;
            ``bert_config`` -- path to the JSON model config (used for both kinds);
            ``load_pretrain_model`` -- checkpoint path, empty/None to skip loading;
            ``checkpoint_activations`` -- enable gradient checkpointing when truthy.
        logger: Object with an ``info`` method, used to report a successful load.

    Returns:
        Tuple ``(config, model, numel)`` where ``numel`` is the total number of
        elements over ``model.parameters()``.

    Raises:
        Exception: If ``args.mlm`` is not a supported model kind.
        AssertionError: If ``args.load_pretrain_model`` is set but does not exist.
    """
    if args.mlm == "bert":
        config = transformers.BertConfig.from_json_file(args.bert_config)
        model = BertForMaskedLM(config)
    elif args.mlm == "deberta_v2":
        config = transformers.DebertaV2Config.from_json_file(args.bert_config)
        model = DebertaV2ForMaskedLM(config)
    else:
        # Only reached for an unknown model kind; must not fire for valid ones.
        raise Exception("Invalid mlm!")

    # Truthiness check also tolerates ``None`` (``len(None)`` would raise).
    if args.load_pretrain_model:
        assert os.path.exists(args.load_pretrain_model)
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        m_state_dict = torch.load(
            args.load_pretrain_model, map_location=torch.device(f"cuda:{torch.cuda.current_device()}")
        )
        # Strict loading: every process must end up with identical parameters.
        model.load_state_dict(m_state_dict, strict=True)
        logger.info("load model success")

    numel = sum(p.numel() for p in model.parameters())
    if args.checkpoint_activations:
        # NOTE(review): assumes the project model classes expose the Hugging
        # Face ``gradient_checkpointing_enable()`` method -- confirm.
        model.gradient_checkpointing_enable()

    return config, model, numel
def get_optimizer(model, lr):
    """Create a ``HybridAdam`` optimizer with two weight-decay groups.

    Parameters whose name contains any of ``"bias"``, ``"gamma"``, ``"beta"``
    or ``"LayerNorm"`` get a weight decay of 0.0; all others get 0.1.

    Args:
        model: Module whose ``named_parameters()`` are optimised.
        lr: Learning rate passed to the optimizer.

    Returns:
        The configured ``HybridAdam`` instance (betas 0.9 / 0.95).
    """
    no_decay = ("bias", "gamma", "beta", "LayerNorm")

    # Single pass over the parameters, preserving their original order
    # inside each group.
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if any(marker in name for marker in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    # configure the weight decay for bert models
    optimizer_grouped_parameters = [
        {"params": decay_params, "weight_decay": 0.1},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
    optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, betas=[0.9, 0.95])
    return optimizer
def get_lr_scheduler(optimizer, total_steps, warmup_steps=2000, last_epoch=-1):
    """Create a linear warmup / linear decay learning-rate scheduler.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        total_steps: Passed as ``num_training_steps``.
        warmup_steps: Passed as ``num_warmup_steps``.
        last_epoch: Passed through as ``last_epoch``; -1 starts a fresh schedule.

    Returns:
        The scheduler returned by ``get_linear_schedule_with_warmup``.
    """
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
        last_epoch=last_epoch,
    )
    return lr_scheduler
def save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step):
    """Save model weights and training state to two files derived from ``path``.

    Files written (by global rank 0 only):
        ``path + "_pytorch_model.bin"`` -- ``model.state_dict()``
        ``path + ".op_lrs"`` -- dict with the optimizer and LR-scheduler state
        dicts plus ``epoch``, ``shard`` and ``global_step``.

    Args:
        model: Model whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        lr_scheduler: Scheduler whose ``state_dict()`` is saved.
        path: Common path prefix for both output files.
        epoch: Current epoch, stored as-is.
        shard: Current data shard, stored as-is.
        global_step: Current global step, stored as-is.
    """
    model_path = path + "_pytorch_model.bin"
    optimizer_lr_path = path + ".op_lrs"
    checkpoint = {
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "epoch": epoch,
        "shard": shard,
        "global_step": global_step,
    }
    # Every process must call model.state_dict(), even though only rank 0
    # writes the result, so this stays outside the rank check.
    model_state = model.state_dict()
    if gpc.get_global_rank() == 0:
        torch.save(checkpoint, optimizer_lr_path)
        torch.save(model_state, model_path)