import transformers
import logging
from colossalai.nn.lr_scheduler import LinearWarmupLR
from transformers import get_linear_schedule_with_warmup
from transformers import BertForPreTraining, RobertaForMaskedLM, RobertaConfig
from transformers import GPT2Config, GPT2LMHeadModel
from transformers import AutoTokenizer, AutoModelForMaskedLM
from colossalai.nn.optimizer import FusedAdam
from torch.optim import AdamW
from colossalai.core import global_context as gpc
import torch
import os
import sys

sys.path.append(os.getcwd())
from model.deberta_v2 import DebertaV2ForMaskedLM
from model.bert import BertForMaskedLM
import torch.nn as nn
from collections import OrderedDict

__all__ = ['get_model', 'get_optimizer', 'get_lr_scheduler', 'get_dataloader_for_pretraining']


def get_new_state_dict(state_dict, start_index=13):
    """Strip a fixed-length prefix (e.g. a 13-character 'module.model.' wrapper prefix) from every checkpoint key."""
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[start_index:]
        new_state_dict[name] = v
    return new_state_dict


class LMModel(nn.Module):
    """Thin wrapper that optionally enables gradient checkpointing around the underlying language model."""

    def __init__(self, model, config, args):
        super().__init__()
        self.checkpoint = args.checkpoint_activations
        self.config = config
        self.model = model
        if self.checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        # Only return lm_logits
        return self.model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)


def get_model(args, logger):
    """Build a masked-LM model (BERT or DeBERTa-v2) from a JSON config and optionally load pretrained weights."""
    if args.mlm == 'bert':
        config = transformers.BertConfig.from_json_file(args.bert_config)
        model = BertForMaskedLM(config)
    elif args.mlm == 'deberta_v2':
        config = transformers.DebertaV2Config.from_json_file(args.bert_config)
        model = DebertaV2ForMaskedLM(config)
    else:
        raise ValueError(f"Invalid mlm: {args.mlm}")

    if len(args.load_pretrain_model) > 0:
        assert os.path.exists(args.load_pretrain_model), f"checkpoint not found: {args.load_pretrain_model}"
        # load_checkpoint(args.load_pretrain_model, model, strict=False)
        m_state_dict = torch.load(args.load_pretrain_model,
                                  map_location=torch.device(f"cuda:{torch.cuda.current_device()}"))
        # new_state_dict = get_new_state_dict(m_state_dict)
        # must ensure that every process has identical parameters
        model.load_state_dict(m_state_dict, strict=True)
        logger.info(f"Loaded pretrained weights from {args.load_pretrain_model}")

    numel = sum(p.numel() for p in model.parameters())
    if args.checkpoint_activations:
        model.gradient_checkpointing_enable()
    # model = LMModel(model, config, args)
    return config, model, numel


def get_optimizer(model, lr):
    """Create a FusedAdam optimizer with weight decay disabled for bias/LayerNorm parameters."""
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

    # configure the weight decay for bert models
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.1
    }, {
        'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]

    optimizer = FusedAdam(optimizer_grouped_parameters, lr=lr, betas=(0.9, 0.95))
    return optimizer


def get_lr_scheduler(optimizer, total_steps, warmup_steps=2000, last_epoch=-1):
    """Linear warmup followed by linear decay to zero."""
    # warmup_steps = int(total_steps * warmup_ratio)
    lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                                   num_warmup_steps=warmup_steps,
                                                   num_training_steps=total_steps,
                                                   last_epoch=last_epoch)
    # lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps)
    return lr_scheduler


def save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step):
    """Save model weights plus optimizer/scheduler/progress state; only global rank 0 writes to disk."""
    model_path = path + '_pytorch_model.bin'
    optimizer_lr_path = path + '.op_lrs'
    checkpoint = {
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
        'shard': shard,
        'global_step': global_step,
    }
    model_state = model.state_dict()  # each process must call model.state_dict()
    if gpc.get_global_rank() == 0:
        torch.save(checkpoint, optimizer_lr_path)
        torch.save(model_state, model_path)