mirror of https://github.com/hpcaitech/ColossalAI
added gpt model & benchmark (#95)
parent 01a80cd86d
commit e5b9f9a08d
@@ -11,12 +11,7 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn import Accuracy, CrossEntropyLoss
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.trainer import Trainer
-from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook,
-                                      LogMetricByEpochHook,
-                                      LogMetricByStepHook,
-                                      LogTimingByEpochHook, LossHook,
-                                      LRSchedulerHook, ThroughputHook)
+from colossalai.trainer import Trainer, hooks
 from colossalai.utils import MultiTimer, get_dataloader
 from model_zoo.vit import vit_lite_depth7_patch4_32
 from torchvision import transforms
@@ -100,22 +95,22 @@ def train_cifar():
     trainer = Trainer(engine=engine, logger=logger, timer=timer)
     logger.info("Trainer is built", ranks=[0])
 
-    hooks = [
-        LogMetricByEpochHook(logger=logger),
-        LogMetricByStepHook(),
-        # LogTimingByEpochHook(timer=timer, logger=logger),
-        # LogMemoryByEpochHook(logger=logger),
-        AccuracyHook(accuracy_func=Accuracy()),
-        LossHook(),
-        ThroughputHook(),
-        LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)
+    hook_list = [
+        hooks.LogMetricByEpochHook(logger=logger),
+        hooks.LogMetricByStepHook(),
+        # hooks.LogTimingByEpochHook(timer=timer, logger=logger),
+        # hooks.LogMemoryByEpochHook(logger=logger),
+        hooks.AccuracyHook(accuracy_func=Accuracy()),
+        hooks.LossHook(),
+        hooks.ThroughputHook(),
+        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)
     ]
 
     logger.info("Train start", ranks=[0])
     trainer.fit(train_dataloader=train_dataloader,
                 test_dataloader=test_dataloader,
                 epochs=gpc.config.NUM_EPOCHS,
-                hooks=hooks,
+                hooks=hook_list,
                 display_progress=True,
                 test_interval=1)
@@ -0,0 +1,29 @@
from colossalai.amp import AMP_TYPE

VOCAB_SIZE = 50304
SEQ_LENGTH = 1024

TOTAL_BATCH_SIZE = 256
LEARNING_RATE = 0.00015
WEIGHT_DECAY = 1e-2

TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'

NUM_EPOCHS = 60
WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36)

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/"
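As a quick sanity check on the batch-size knobs in this config, the sketch below spells out how they combine. It is not part of the commit; the data-parallel size is an assumed example value (it comes from the launch topology, not from this file), and the per-rank split mirrors the dataloader construction in the training script further down.

# Hypothetical arithmetic for the '1d' config above: e.g. 4 GPUs with
# TENSOR_PARALLEL_SIZE = 2 would leave a data-parallel size of 2 (assumption).
TOTAL_BATCH_SIZE = 256
gradient_accumulation = 2
data_parallel_size = 2  # assumed example, not part of the config file

# Samples consumed per engine step (the config's BATCH_SIZE).
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation      # 128
# Samples each data-parallel rank loads per step (see get_dataloader below).
per_rank_batch_size = BATCH_SIZE // data_parallel_size      # 64
# Samples contributing to one optimizer update after accumulation.
effective_batch_size = BATCH_SIZE * gradient_accumulation   # 256

print(BATCH_SIZE, per_rank_batch_size, effective_batch_size)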
@@ -0,0 +1,29 @@
from colossalai.amp import AMP_TYPE

VOCAB_SIZE = 50304
SEQ_LENGTH = 1024

TOTAL_BATCH_SIZE = 256
LEARNING_RATE = 0.00015
WEIGHT_DECAY = 1e-2

TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '2d'

NUM_EPOCHS = 60
WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36)

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 1

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/"
@@ -0,0 +1,30 @@
from colossalai.amp import AMP_TYPE

VOCAB_SIZE = 50304
SEQ_LENGTH = 1024

TOTAL_BATCH_SIZE = 256
LEARNING_RATE = 0.00015
WEIGHT_DECAY = 1e-2

TENSOR_PARALLEL_SIZE = 4
DEPTH = 1
TENSOR_PARALLEL_MODE = '2.5d'

NUM_EPOCHS = 60
WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36)

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=DEPTH),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 1

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/"
@@ -0,0 +1,29 @@
from colossalai.amp import AMP_TYPE

VOCAB_SIZE = 50304
SEQ_LENGTH = 1024

TOTAL_BATCH_SIZE = 256
LEARNING_RATE = 0.00015
WEIGHT_DECAY = 1e-2

TENSOR_PARALLEL_SIZE = 8
TENSOR_PARALLEL_MODE = '3d'

NUM_EPOCHS = 60
WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36)

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 1

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/"
@@ -0,0 +1,29 @@
from colossalai.amp import AMP_TYPE

VOCAB_SIZE = 50304
SEQ_LENGTH = 1024

TOTAL_BATCH_SIZE = 256
LEARNING_RATE = 0.00015
WEIGHT_DECAY = 1e-2

TENSOR_PARALLEL_SIZE = 1
TENSOR_PARALLEL_MODE = None

NUM_EPOCHS = 60
WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36)

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 1

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/"
@@ -0,0 +1,37 @@
import json
import os

import torch
from colossalai.registry import DATASETS
from torch.utils.data import Dataset
from transformers import GPT2Tokenizer


@DATASETS.register_module
class WebtextDataset(Dataset):
    def __init__(self, path, seq_len=1024) -> None:
        super().__init__()
        root = os.path.dirname(path)
        encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
        if os.path.isfile(encoded_data_cache_path):
            seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
            if seq_len_ == seq_len:
                self.data = data
                self.attention_mask = attention_mask
                return
        raw_data = []
        with open(path) as f:
            for line in f.readlines():
                raw_data.append(json.loads(line)['text'])
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.unk_token
        encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
        self.data = encoded_data['input_ids']
        self.attention_mask = encoded_data['attention_mask']
        torch.save((seq_len, self.data, self.attention_mask), encoded_data_cache_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return (self.data[index], self.attention_mask[index]), self.data[index]
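A hypothetical smoke test for this loader, not part of the commit: it writes a couple of {"text": ...} JSON lines (the format __init__ expects), builds the dataset with a short sequence length, and prints the shapes returned by __getitem__. It assumes the file is saved as data.py on the path and that transformers can fetch the GPT-2 tokenizer.

import json
import os
import tempfile

from data import WebtextDataset  # the file added above

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'webtext_sample.jsonl')
    with open(path, 'w') as f:
        for text in ['a tiny webtext sample', 'another line of text']:
            f.write(json.dumps({'text': text}) + '\n')

    dataset = WebtextDataset(path, seq_len=32)        # also caches a .pt file next to the source
    (input_ids, attention_mask), labels = dataset[0]  # labels are the input ids themselves
    print(len(dataset), input_ids.shape, attention_mask.shape, labels.shape)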
@@ -0,0 +1,105 @@
import contextlib
import os

import colossalai
import torch
from colossalai.core import global_context as gpc
from colossalai.engine.schedule import (InterleavedPipelineSchedule, PipelineSchedule)
from colossalai.logging import get_dist_logger
from colossalai.nn import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from colossalai.zero import zero3_model_context
from model_zoo.gpt import GPTLMLoss, gpt2_small, gpt2_medium, gpt2_large, gpt2_xl

from data import WebtextDataset


def train_gpt():
    args = colossalai.get_default_parser().parse_args()
    # standard launch
    # colossalai.launch(config=args.config,
    #                   rank=args.rank,
    #                   world_size=args.world_size,
    #                   local_rank=args.local_rank,
    #                   host=args.host,
    #                   port=args.port)

    # launch from torchrun
    colossalai.launch_from_torch(config=args.config)

    logger = get_dist_logger()
    if hasattr(gpc.config, 'LOG_PATH'):
        if gpc.get_global_rank() == 0:
            log_path = gpc.config.LOG_PATH
            if not os.path.exists(log_path):
                os.mkdir(log_path)
            logger.log_to_file(log_path)

    train_dataset = WebtextDataset(os.environ['DATA'], seq_len=gpc.config.SEQ_LENGTH)
    train_dataloader = get_dataloader(train_dataset,
                                      seed=42,
                                      batch_size=gpc.config.BATCH_SIZE // gpc.data_parallel_size,
                                      pin_memory=True,
                                      shuffle=True,
                                      drop_last=True)
    logger.info(f'Loaded {len(train_dataset)}/{len(train_dataloader)} samples/batches', ranks=[0])

    # zero3 under test
    # use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
    # cm = zero3_model_context() if use_zero3 else contextlib.nullcontext()
    # with cm:
    #     model = gpc.config.model.pop('type')(**gpc.config.model)

    model = gpt2_medium(vocab_size=gpc.config.VOCAB_SIZE,
                        max_position_embeddings=gpc.config.SEQ_LENGTH,
                        checkpoint=True)

    criterion = GPTLMLoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.00015, weight_decay=1e-2)

    steps_per_epoch = len(train_dataloader) // gpc.config.gradient_accumulation

    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
                                           total_steps=gpc.config.NUM_EPOCHS * steps_per_epoch,
                                           warmup_steps=gpc.config.WARMUP_EPOCHS * steps_per_epoch,
                                           eta_min=1e-5)

    engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model=model,
                                                                      optimizer=optimizer,
                                                                      criterion=criterion,
                                                                      train_dataloader=train_dataloader,
                                                                      lr_scheduler=lr_scheduler)

    # pipeline under test
    # num_model_chunks = getattr(gpc.config.model, 'num_chunks', 1)
    # if num_model_chunks > 1:
    #     logger.info('Build InterleavedPipelineSchedule', ranks=[0])
    #     schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES, num_model_chunks)
    # else:
    #     logger.info('Build PipelineSchedule', ranks=[0])
    #     schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES)

    timer = MultiTimer()

    trainer = Trainer(engine=engine, logger=logger, timer=timer)

    hook_list = [
        hooks.LogMetricByEpochHook(logger=logger),
        hooks.LogMetricByStepHook(),
        hooks.LossHook(),
        hooks.ThroughputHook(),
        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
        # hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
        # hooks.LogMemoryByEpochHook(logger),
        # hooks.LogTimingByEpochHook(timer, logger),
        # hooks.SaveCheckpointHook(checkpoint_dir='./ckpt')
    ]

    logger.info("Training start", ranks=[0])
    trainer.fit(train_dataloader=train_dataloader, epochs=gpc.config.NUM_EPOCHS, hooks=hook_list, display_progress=True)


if __name__ == '__main__':
    train_gpt()
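The scheduler wiring above converts epoch counts into optimizer-step counts. A back-of-the-envelope check with assumed numbers follows; the real dataloader length depends on the webtext shard and the per-rank batch size, so the 10000 below is purely illustrative.

NUM_EPOCHS = 60
WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36)        # 21 epochs of warmup
gradient_accumulation = 2                      # value from the '1d' config
len_train_dataloader = 10000                   # assumed example value

steps_per_epoch = len_train_dataloader // gradient_accumulation   # optimizer steps per epoch
total_steps = NUM_EPOCHS * steps_per_epoch                        # passed to CosineAnnealingWarmupLR
warmup_steps = WARMUP_EPOCHS * steps_per_epoch                    # linear warmup portion

print(steps_per_epoch, total_steps, warmup_steps)   # 5000 300000 105000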
@@ -14,9 +14,7 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn import Accuracy, CrossEntropyLoss
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.trainer import Trainer
-from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook,
-                                      LogTimingByEpochHook, LossHook, LRSchedulerHook, ThroughputHook)
+from colossalai.trainer import Trainer, hooks
 from colossalai.utils import MultiTimer
 from model_zoo.vit import vit_small_patch16_224
 from nvidia.dali import types
@@ -185,22 +183,22 @@ def train_imagenet():
     trainer = Trainer(engine=engine, logger=logger, timer=timer)
     logger.info("Trainer is built", ranks=[0])
 
-    hooks = [
-        LogMetricByEpochHook(logger=logger),
-        LogMetricByStepHook(),
-        # LogTimingByEpochHook(timer=timer, logger=logger),
-        # LogMemoryByEpochHook(logger=logger),
-        AccuracyHook(accuracy_func=Accuracy()),
-        LossHook(),
-        ThroughputHook(),
-        LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
+    hook_list = [
+        hooks.LogMetricByEpochHook(logger=logger),
+        hooks.LogMetricByStepHook(),
+        # hooks.LogTimingByEpochHook(timer=timer, logger=logger),
+        # hooks.LogMemoryByEpochHook(logger=logger),
+        hooks.AccuracyHook(accuracy_func=Accuracy()),
+        hooks.LossHook(),
+        hooks.ThroughputHook(),
+        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
     ]
 
     logger.info("Train start", ranks=[0])
     trainer.fit(train_dataloader=train_dataloader,
                 test_dataloader=test_dataloader,
                 epochs=gpc.config.NUM_EPOCHS,
-                hooks=hooks,
+                hooks=hook_list,
                 display_progress=True,
                 test_interval=1)
@@ -14,9 +14,7 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn import Accuracy, CrossEntropyLoss
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.trainer import Trainer
-from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook,
-                                      LogTimingByEpochHook, LossHook, LRSchedulerHook, ThroughputHook)
+from colossalai.trainer import Trainer, hooks
 from colossalai.utils import MultiTimer
 from model_zoo.vit import vit_small_patch16_224
 from nvidia.dali import types
@@ -185,22 +183,22 @@ def train_imagenet():
     trainer = Trainer(engine=engine, logger=logger, timer=timer)
     logger.info("Trainer is built", ranks=[0])
 
-    hooks = [
-        LogMetricByEpochHook(logger=logger),
-        LogMetricByStepHook(),
-        # LogTimingByEpochHook(timer=timer, logger=logger),
-        # LogMemoryByEpochHook(logger=logger),
-        AccuracyHook(accuracy_func=Accuracy()),
-        LossHook(),
-        ThroughputHook(),
-        LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
+    hook_list = [
+        hooks.LogMetricByEpochHook(logger=logger),
+        hooks.LogMetricByStepHook(),
+        # hooks.LogTimingByEpochHook(timer=timer, logger=logger),
+        # hooks.LogMemoryByEpochHook(logger=logger),
+        hooks.AccuracyHook(accuracy_func=Accuracy()),
+        hooks.LossHook(),
+        hooks.ThroughputHook(),
+        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
     ]
 
     logger.info("Train start", ranks=[0])
     trainer.fit(train_dataloader=train_dataloader,
                 test_dataloader=test_dataloader,
                 epochs=gpc.config.NUM_EPOCHS,
-                hooks=hooks,
+                hooks=hook_list,
                 display_progress=True,
                 test_interval=1)
@@ -0,0 +1 @@
from .gpt import *
@@ -0,0 +1,284 @@
import math
from typing import Callable

import torch
from colossalai import nn as col_nn
from colossalai.nn.layer.utils import CheckpointModule
from colossalai.registry import LAYERS, MODELS, LOSSES
from colossalai.utils import get_current_device
from torch import dtype, nn

__all__ = ['GPT', 'GPTLMLoss', 'gpt2_small', 'gpt2_medium', 'gpt2_large', 'gpt2_xl', 'gpt3']


@LAYERS.register_module
class GPTEmbedding(nn.Module):
    def __init__(self,
                 embedding_dim: int,
                 vocab_size: int,
                 max_position_embeddings: int,
                 num_tokentypes: int = 0,
                 padding_idx: int = 0,
                 dropout: float = 0.,
                 dtype: dtype = None) -> None:
        super().__init__()
        self.word_embeddings = col_nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx, dtype=dtype)
        self.position_embeddings = col_nn.Embedding(max_position_embeddings, embedding_dim, dtype=dtype)
        if num_tokentypes > 0:
            self.tokentype_embeddings = col_nn.Embedding(num_tokentypes, embedding_dim, dtype=dtype)
        else:
            self.tokentype_embeddings = None
        self.dropout = col_nn.Dropout(dropout)

    @property
    def word_embedding_weight(self):
        return self.word_embeddings.weight

    def forward(self, input_ids, position_ids=None, tokentype_ids=None):
        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=get_current_device()).unsqueeze(0)
        x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids)
        if self.tokentype_embeddings is not None and tokentype_ids is not None:
            x = x + self.tokentype_embeddings(tokentype_ids)
        x = self.dropout(x)
        return x


@LAYERS.register_module
class GPTSelfAttention(nn.Module):
    def __init__(self,
                 dim: int,
                 num_heads: int,
                 attention_dropout: float,
                 dropout: float,
                 bias: bool = True,
                 dtype: dtype = None) -> None:
        super().__init__()

        self.attention_head_size = dim // num_heads
        self.query_key_value = col_nn.Linear(dim, 3 * dim, dtype=dtype, bias=bias)
        self.attention_dropout = col_nn.Dropout(attention_dropout)
        self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=True)
        self.dropout = col_nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, attention_mask=None):
        qkv = self.query_key_value(x)
        all_head_size = qkv.shape[-1] // 3
        num_attention_heads = all_head_size // self.attention_head_size
        new_qkv_shape = qkv.shape[:-1] + \
            (num_attention_heads, 3 * self.attention_head_size)
        qkv = qkv.view(new_qkv_shape)
        qkv = qkv.permute((0, 2, 1, 3))
        q, k, v = torch.chunk(qkv, 3, dim=-1)

        x = torch.matmul(q, k.transpose(-1, -2))
        x = x / math.sqrt(self.attention_head_size)

        # causal mask
        q_len, k_len = q.size(-2), k.size(-2)
        causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8,
                                            device=get_current_device())).view(1, 1, q_len, k_len).bool()
        x = torch.where(causal_mask, x, torch.tensor(-1e4, dtype=x.dtype, device=get_current_device()))

        if attention_mask is not None:
            x = x + attention_mask
        x = self.softmax(x)
        x = self.attention_dropout(x)

        x = torch.matmul(x, v)
        x = x.transpose(1, 2)
        new_context_layer_shape = x.size()[:-2] + (all_head_size, )
        x = x.reshape(new_context_layer_shape)

        x = self.dense(x)
        x = self.dropout(x)

        return x


@LAYERS.register_module
class GPTMLP(nn.Module):
    def __init__(self,
                 dim: int,
                 mlp_ratio: int,
                 activation: Callable,
                 dropout: float,
                 dtype: dtype = None,
                 bias: bool = True):
        super().__init__()
        self.dense_1 = col_nn.Linear(dim, mlp_ratio * dim, dtype=dtype, bias=bias)
        self.activation = activation
        self.dense_2 = col_nn.Linear(mlp_ratio * dim, dim, dtype=dtype, bias=bias)
        self.dropout = col_nn.Dropout(dropout)

    def forward(self, x):
        x = self.dense_1(x)
        x = self.activation(x)
        x = self.dense_2(x)
        x = self.dropout(x)
        return x


@LAYERS.register_module
class GPTBlock(CheckpointModule):
    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: int,
                 activation: Callable,
                 attention_dropout: float = 0.,
                 dropout: float = 0.,
                 dtype: dtype = None,
                 bias: bool = True,
                 checkpoint: bool = False):
        super().__init__()
        self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
        self.attn = GPTSelfAttention(dim=dim,
                                     num_heads=num_heads,
                                     attention_dropout=attention_dropout,
                                     dropout=dropout,
                                     bias=bias,
                                     dtype=dtype)
        self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
        self.mlp = GPTMLP(dim=dim, mlp_ratio=mlp_ratio, activation=activation, dropout=dropout, dtype=dtype, bias=bias)

    def _forward(self, x, attention_mask=None):
        x = x + self.attn(self.norm1(x), attention_mask)
        x = x + self.mlp(self.norm2(x))
        return x, attention_mask


@LAYERS.register_module
class GPTLMHead(nn.Module):
    def __init__(self,
                 dim: int,
                 vocab_size: int,
                 word_embeeding_weight: nn.Parameter = None,
                 bias: bool = False,
                 dtype: dtype = None) -> None:
        super().__init__()
        self.dense = col_nn.Classifier(dim, vocab_size, word_embeeding_weight, bias=bias, dtype=dtype)

    def forward(self, x):
        x = self.dense(x)
        return x


@LOSSES.register_module
class GPTLMLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss = col_nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


@MODELS.register_module
class GPT(nn.Module):
    def __init__(self,
                 vocab_size: int = 50304,
                 max_position_embeddings: int = 1024,
                 dim: int = 768,
                 num_heads: int = 12,
                 depth: int = 12,
                 mlp_ratio: int = 4,
                 dropout: float = 0.1,
                 embedding_dropout: float = 0.1,
                 attention_dropout: float = 0.1,
                 layernorm_epsilon: float = 1e-5,
                 activation: Callable = nn.functional.gelu,
                 checkpoint: bool = False,
                 dtype: dtype = None,
                 bias: bool = True,
                 padding_idx: int = 0) -> None:
        super().__init__()
        self.dtype = dtype
        self.embed = GPTEmbedding(embedding_dim=dim,
                                  vocab_size=vocab_size,
                                  max_position_embeddings=max_position_embeddings,
                                  padding_idx=padding_idx,
                                  dropout=embedding_dropout,
                                  dtype=dtype)
        self.blocks = nn.ModuleList([
            GPTBlock(
                dim=dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                activation=activation,
                attention_dropout=attention_dropout,
                dropout=dropout,
                dtype=dtype,
                bias=bias,
                checkpoint=checkpoint,
            ) for _ in range(depth)
        ])

        self.norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)

        self.head = GPTLMHead(dim=dim,
                              vocab_size=vocab_size,
                              word_embeeding_weight=self.embed.word_embedding_weight,
                              bias=bias,
                              dtype=dtype)

    def forward(self, input_ids, attention_mask=None):
        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # Adapted from huggingface
        if attention_mask is not None:
            batch_size = input_ids.shape[0]
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        x = self.embed(input_ids)

        for block in self.blocks:
            x, attention_mask = block(x, attention_mask)

        x = self.head(self.norm(x))

        return x


def _create_gpt_model(**model_kwargs):
    model = GPT(**model_kwargs)
    return model


@MODELS.register_module
def gpt2_small(**kwargs):
    model_kwargs = dict(dim=768, depth=12, num_heads=12, **kwargs)
    return _create_gpt_model(**model_kwargs)


@MODELS.register_module
def gpt2_medium(**kwargs):
    model_kwargs = dict(dim=1024, depth=24, num_heads=16, **kwargs)
    return _create_gpt_model(**model_kwargs)


@MODELS.register_module
def gpt2_large(**kwargs):
    model_kwargs = dict(dim=1280, depth=36, num_heads=20, **kwargs)
    return _create_gpt_model(**model_kwargs)


@MODELS.register_module
def gpt2_xl(**kwargs):
    model_kwargs = dict(dim=1600, depth=48, num_heads=25, **kwargs)
    return _create_gpt_model(**model_kwargs)


@MODELS.register_module
def gpt3(**kwargs):
    model_kwargs = dict(dim=12288, max_position_embeddings=2048, depth=96, num_heads=96, **kwargs)
    return _create_gpt_model(**model_kwargs)
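A hedged usage sketch for the factory functions above, not part of the commit: it assumes a ColossalAI context has already been initialized (as train_gpt.py does via colossalai.launch_from_torch), since the col_nn layers pick their implementation from the active parallel mode; the tensor shapes are purely illustrative.

import torch

from model_zoo.gpt import GPTLMLoss, gpt2_small

# Build the smallest variant without activation checkpointing.
model = gpt2_small(vocab_size=50304, max_position_embeddings=1024, checkpoint=False)
criterion = GPTLMLoss()

input_ids = torch.randint(0, 50304, (2, 16))   # (batch, seq_len)
logits = model(input_ids)                      # (batch, seq_len, vocab_size)
loss = criterion(logits, input_ids)            # shifted next-token prediction loss
print(logits.shape, loss.item())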
@@ -89,7 +89,7 @@ class ViTEmbedding(nn.Module):
 
 
 @LAYERS.register_module
-class ViTSelfAttention(CheckpointModule):
+class ViTSelfAttention(nn.Module):
     def __init__(self,
                  dim: int,
                  num_heads: int,
@@ -97,9 +97,8 @@ class ViTSelfAttention(CheckpointModule):
                  dropout: float,
                  bias: bool = True,
                  dtype: dtype = None,
-                 checkpoint: bool = False,
                  init_method: str = 'torch'):
-        super().__init__(checkpoint)
+        super().__init__()
         self.attention_head_size = dim // num_heads
         self.query_key_value = col_nn.Linear(dim,
                                              3 * dim,
@@ -111,7 +110,7 @@ class ViTSelfAttention(CheckpointModule):
         self.dropout = col_nn.Dropout(dropout)
         self.softmax = nn.Softmax(dim=-1)
 
-    def _forward(self, x):
+    def forward(self, x):
         qkv = self.query_key_value(x)
         all_head_size = qkv.shape[-1] // 3
         num_attention_heads = all_head_size // self.attention_head_size
@@ -138,7 +137,7 @@ class ViTSelfAttention(CheckpointModule):
 
 
 @LAYERS.register_module
-class ViTMLP(CheckpointModule):
+class ViTMLP(nn.Module):
     def __init__(self,
                  dim: int,
                  mlp_ratio: int,
@@ -146,9 +145,8 @@ class ViTMLP(CheckpointModule):
                  dropout: float,
                  dtype: dtype = None,
                  bias: bool = True,
-                 checkpoint: bool = False,
                  init_method: str = 'torch'):
-        super().__init__(checkpoint)
+        super().__init__()
         self.dense_1 = col_nn.Linear(dim,
                                      mlp_ratio * dim,
                                      dtype=dtype,
@@ -163,7 +161,7 @@ class ViTMLP(CheckpointModule):
                                      **_init_rules[init_method]['transformer'])
         self.dropout_2 = col_nn.Dropout(dropout)
 
-    def _forward(self, x):
+    def forward(self, x):
         x = self.dense_1(x)
         x = self.activation(x)
         x = self.dropout_1(x)
@@ -192,22 +190,22 @@ class ViTHead(nn.Module):
             self.representation = None
             representation_size = dim
 
-        self.linear = col_nn.Classifier(representation_size,
-                                        num_classes,
-                                        dtype=dtype,
-                                        bias=bias,
-                                        **_init_rules[init_method]['head'])
+        self.dense = col_nn.Classifier(representation_size,
+                                       num_classes,
+                                       dtype=dtype,
+                                       bias=bias,
+                                       **_init_rules[init_method]['head'])
 
     def forward(self, x):
         x = x[:, 0]
         if self.representation is not None:
             x = self.representation(x)
-        x = self.linear(x)
+        x = self.dense(x)
         return x
 
 
 @LAYERS.register_module
-class ViTBlock(nn.Module):
+class ViTBlock(CheckpointModule):
     def __init__(self,
                  dim: int,
                  num_heads: int,
@@ -216,32 +214,31 @@ class ViTBlock(nn.Module):
                  attention_dropout: float = 0.,
                  dropout: float = 0.,
                  drop_path: float = 0.,
+                 layernorm_epsilon: float = 1e-6,
                  dtype: dtype = None,
                  bias: bool = True,
                  checkpoint: bool = False,
                  init_method: str = 'torch'):
-        super().__init__()
-        self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
+        super().__init__(checkpoint)
+        self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
         self.attn = ViTSelfAttention(dim=dim,
                                      num_heads=num_heads,
                                      attention_dropout=attention_dropout,
                                      dropout=dropout,
                                      bias=bias,
                                      dtype=dtype,
-                                     checkpoint=checkpoint,
                                      init_method=init_method)
         self.drop_path = col_nn.DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
+        self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
         self.mlp = ViTMLP(dim=dim,
                           mlp_ratio=mlp_ratio,
                           activation=activation,
                           dropout=dropout,
                           dtype=dtype,
                           bias=bias,
-                          checkpoint=checkpoint,
                           init_method=init_method)
 
-    def forward(self, x):
+    def _forward(self, x):
         x = x + self.drop_path(self.attn(self.norm1(x)))
         x = x + self.drop_path(self.mlp(self.norm2(x)))
         return x
@@ -261,6 +258,7 @@ class VisionTransformer(nn.Module):
                  attention_dropout: float = 0.,
                  dropout: float = 0.1,
                  drop_path: float = 0.,
+                 layernorm_epsilon: float = 1e-6,
                  activation: Callable = nn.functional.gelu,
                  representation_size: int = None,
                  dtype: dtype = None,
@@ -295,7 +293,7 @@ class VisionTransformer(nn.Module):
             ) for i in range(depth)
         ]
 
-        norm = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
+        norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
 
         head = ViTHead(dim=dim,
                        num_classes=num_classes,
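The ViT edits above mirror the new GPTBlock: activation checkpointing moves out of the attention and MLP submodules and into the transformer block, which subclasses CheckpointModule and implements _forward. A generic, hedged sketch of that pattern follows, using torch.utils.checkpoint directly; the real CheckpointModule in colossalai.nn.layer.utils may differ in its details.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(nn.Module):
    """Illustrative block: recompute activations during backward when enabled."""

    def __init__(self, dim: int, use_checkpoint: bool = True):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.use_checkpoint = use_checkpoint

    def _forward(self, x):
        # The actual computation lives here, just like ViTBlock._forward above.
        return x + self.mlp(self.norm(x))

    def forward(self, x):
        if self.use_checkpoint and x.requires_grad:
            # Drop intermediate activations now, recompute them in backward.
            return checkpoint(self._forward, x)
        return self._forward(x)


x = torch.randn(2, 16, 64, requires_grad=True)
block = CheckpointedBlock(64)
block(x).sum().backward()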