added gpt model & benchmark (#95)

pull/97/head
アマデウス 2021-12-30 14:43:30 +08:00 committed by GitHub
parent 01a80cd86d
commit e5b9f9a08d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 626 additions and 64 deletions

View File

@ -11,12 +11,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, CrossEntropyLoss from colossalai.nn import Accuracy, CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer from colossalai.trainer import Trainer, hooks
from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook,
LogMetricByEpochHook,
LogMetricByStepHook,
LogTimingByEpochHook, LossHook,
LRSchedulerHook, ThroughputHook)
from colossalai.utils import MultiTimer, get_dataloader from colossalai.utils import MultiTimer, get_dataloader
from model_zoo.vit import vit_lite_depth7_patch4_32 from model_zoo.vit import vit_lite_depth7_patch4_32
from torchvision import transforms from torchvision import transforms
@ -100,22 +95,22 @@ def train_cifar():
trainer = Trainer(engine=engine, logger=logger, timer=timer) trainer = Trainer(engine=engine, logger=logger, timer=timer)
logger.info("Trainer is built", ranks=[0]) logger.info("Trainer is built", ranks=[0])
hooks = [ hook_list = [
LogMetricByEpochHook(logger=logger), hooks.LogMetricByEpochHook(logger=logger),
LogMetricByStepHook(), hooks.LogMetricByStepHook(),
# LogTimingByEpochHook(timer=timer, logger=logger), # hooks.LogTimingByEpochHook(timer=timer, logger=logger),
# LogMemoryByEpochHook(logger=logger), # hooks.LogMemoryByEpochHook(logger=logger),
AccuracyHook(accuracy_func=Accuracy()), hooks.AccuracyHook(accuracy_func=Accuracy()),
LossHook(), hooks.LossHook(),
ThroughputHook(), hooks.ThroughputHook(),
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False) hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)
] ]
logger.info("Train start", ranks=[0]) logger.info("Train start", ranks=[0])
trainer.fit(train_dataloader=train_dataloader, trainer.fit(train_dataloader=train_dataloader,
test_dataloader=test_dataloader, test_dataloader=test_dataloader,
epochs=gpc.config.NUM_EPOCHS, epochs=gpc.config.NUM_EPOCHS,
hooks=hooks, hooks=hook_list,
display_progress=True, display_progress=True,
test_interval=1) test_interval=1)

View File

@ -0,0 +1,29 @@
from colossalai.amp import AMP_TYPE
VOCAB_SIZE = 50304
SEQ_LENGTH = 1024
TOTAL_BATCH_SIZE = 256
LEARNING_RATE = 0.00015
WEIGHT_DECAY = 1e-2
TENSOR_PARALLEL_SIZE = 2
TENSOR_PARALLEL_MODE = '1d'
NUM_EPOCHS = 60
WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36)
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
fp16 = dict(mode=AMP_TYPE.TORCH, )
gradient_accumulation = 2
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/"

View File

@ -0,0 +1,29 @@
from colossalai.amp import AMP_TYPE
VOCAB_SIZE = 50304
SEQ_LENGTH = 1024
TOTAL_BATCH_SIZE = 256
LEARNING_RATE = 0.00015
WEIGHT_DECAY = 1e-2
TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '2d'
NUM_EPOCHS = 60
WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36)
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
fp16 = dict(mode=AMP_TYPE.TORCH, )
gradient_accumulation = 1
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/"

View File

@ -0,0 +1,30 @@
from colossalai.amp import AMP_TYPE
VOCAB_SIZE = 50304
SEQ_LENGTH = 1024
TOTAL_BATCH_SIZE = 256
LEARNING_RATE = 0.00015
WEIGHT_DECAY = 1e-2
TENSOR_PARALLEL_SIZE = 4
DEPTH = 1
TENSOR_PARALLEL_MODE = '2.5d'
NUM_EPOCHS = 60
WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36)
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=DEPTH),
)
fp16 = dict(mode=AMP_TYPE.TORCH, )
gradient_accumulation = 1
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/"

View File

@ -0,0 +1,29 @@
from colossalai.amp import AMP_TYPE
VOCAB_SIZE = 50304
SEQ_LENGTH = 1024
TOTAL_BATCH_SIZE = 256
LEARNING_RATE = 0.00015
WEIGHT_DECAY = 1e-2
TENSOR_PARALLEL_SIZE = 8
TENSOR_PARALLEL_MODE = '3d'
NUM_EPOCHS = 60
WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36)
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
fp16 = dict(mode=AMP_TYPE.TORCH, )
gradient_accumulation = 1
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/"

View File

@ -0,0 +1,29 @@
from colossalai.amp import AMP_TYPE
VOCAB_SIZE = 50304
SEQ_LENGTH = 1024
TOTAL_BATCH_SIZE = 256
LEARNING_RATE = 0.00015
WEIGHT_DECAY = 1e-2
TENSOR_PARALLEL_SIZE = 1
TENSOR_PARALLEL_MODE = None
NUM_EPOCHS = 60
WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36)
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
fp16 = dict(mode=AMP_TYPE.TORCH, )
gradient_accumulation = 1
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/"

37
benchmark/gpt2/data.py Normal file
View File

@ -0,0 +1,37 @@
import json
import os
import torch
from colossalai.registry import DATASETS
from torch.utils.data import Dataset
from transformers import GPT2Tokenizer
@DATASETS.register_module
class WebtextDataset(Dataset):
def __init__(self, path, seq_len=1024) -> None:
super().__init__()
root = os.path.dirname(path)
encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
if os.path.isfile(encoded_data_cache_path):
seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
if seq_len_ == seq_len:
self.data = data
self.attention_mask = attention_mask
return
raw_data = []
with open(path) as f:
for line in f.readlines():
raw_data.append(json.loads(line)['text'])
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.unk_token
encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
self.data = encoded_data['input_ids']
self.attention_mask = encoded_data['attention_mask']
torch.save((seq_len, self.data, self.attention_mask), encoded_data_cache_path)
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return (self.data[index], self.attention_mask[index]), self.data[index]

105
benchmark/gpt2/train.py Normal file
View File

@ -0,0 +1,105 @@
import contextlib
import os
import colossalai
import torch
from colossalai.core import global_context as gpc
from colossalai.engine.schedule import (InterleavedPipelineSchedule, PipelineSchedule)
from colossalai.logging import get_dist_logger
from colossalai.nn import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from colossalai.zero import zero3_model_context
from model_zoo.gpt import GPTLMLoss, gpt2_small, gpt2_medium, gpt2_large, gpt2_xl
from data import WebtextDataset
def train_gpt():
args = colossalai.get_default_parser().parse_args()
# standard launch
# colossalai.launch(config=args.config,
# rank=args.rank,
# world_size=args.world_size,
# local_rank=args.local_rank,
# host=args.host,
# port=args.port)
# launch from torchrun
colossalai.launch_from_torch(config=args.config)
logger = get_dist_logger()
if hasattr(gpc.config, 'LOG_PATH'):
if gpc.get_global_rank() == 0:
log_path = gpc.config.LOG_PATH
if not os.path.exists(log_path):
os.mkdir(log_path)
logger.log_to_file(log_path)
train_dataset = WebtextDataset(os.environ['DATA'], seq_len=gpc.config.SEQ_LENGTH)
train_dataloader = get_dataloader(train_dataset,
seed=42,
batch_size=gpc.config.BATCH_SIZE // gpc.data_parallel_size,
pin_memory=True,
shuffle=True,
drop_last=True)
logger.info(f'Loaded {len(train_dataset)}/{len(train_dataloader)} samples/batches', ranks=[0])
# zero3 under test
# use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
# cm = zero3_model_context() if use_zero3 else contextlib.nullcontext()
# with cm:
# model = gpc.config.model.pop('type')(**gpc.config.model)
model = gpt2_medium(vocab_size=gpc.config.VOCAB_SIZE,
max_position_embeddings=gpc.config.SEQ_LENGTH,
checkpoint=True)
criterion = GPTLMLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00015, weight_decay=1e-2)
steps_per_epoch = len(train_dataloader) // gpc.config.gradient_accumulation
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
total_steps=gpc.config.NUM_EPOCHS * steps_per_epoch,
warmup_steps=gpc.config.WARMUP_EPOCHS * steps_per_epoch,
eta_min=1e-5)
engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader,
lr_scheduler=lr_scheduler)
# pipeline under test
# num_model_chunks = getattr(gpc.config.model, 'num_chunks', 1)
# if num_model_chunks > 1:
# logger.info('Build InterleavedPipelineSchedule', ranks=[0])
# schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES, num_model_chunks)
# else:
# logger.info('Build PipelineSchedule', ranks=[0])
# schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES)
timer = MultiTimer()
trainer = Trainer(engine=engine, logger=logger, timer=timer)
hook_list = [
hooks.LogMetricByEpochHook(logger=logger),
hooks.LogMetricByStepHook(),
hooks.LossHook(),
hooks.ThroughputHook(),
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
# hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
# hooks.LogMemoryByEpochHook(logger),
# hooks.LogTimingByEpochHook(timer, logger),
# hooks.SaveCheckpointHook(checkpoint_dir='./ckpt')
]
logger.info("Training start", ranks=[0])
trainer.fit(train_dataloader=train_dataloader, epochs=gpc.config.NUM_EPOCHS, hooks=hook_list, display_progress=True)
if __name__ == '__main__':
train_gpt()

View File

@ -14,9 +14,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, CrossEntropyLoss from colossalai.nn import Accuracy, CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer from colossalai.trainer import Trainer, hooks
from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook,
LogTimingByEpochHook, LossHook, LRSchedulerHook, ThroughputHook)
from colossalai.utils import MultiTimer from colossalai.utils import MultiTimer
from model_zoo.vit import vit_small_patch16_224 from model_zoo.vit import vit_small_patch16_224
from nvidia.dali import types from nvidia.dali import types
@ -185,22 +183,22 @@ def train_imagenet():
trainer = Trainer(engine=engine, logger=logger, timer=timer) trainer = Trainer(engine=engine, logger=logger, timer=timer)
logger.info("Trainer is built", ranks=[0]) logger.info("Trainer is built", ranks=[0])
hooks = [ hook_list = [
LogMetricByEpochHook(logger=logger), hooks.LogMetricByEpochHook(logger=logger),
LogMetricByStepHook(), hooks.LogMetricByStepHook(),
# LogTimingByEpochHook(timer=timer, logger=logger), # hooks.LogTimingByEpochHook(timer=timer, logger=logger),
# LogMemoryByEpochHook(logger=logger), # hooks.LogMemoryByEpochHook(logger=logger),
AccuracyHook(accuracy_func=Accuracy()), hooks.AccuracyHook(accuracy_func=Accuracy()),
LossHook(), hooks.LossHook(),
ThroughputHook(), hooks.ThroughputHook(),
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True) hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
] ]
logger.info("Train start", ranks=[0]) logger.info("Train start", ranks=[0])
trainer.fit(train_dataloader=train_dataloader, trainer.fit(train_dataloader=train_dataloader,
test_dataloader=test_dataloader, test_dataloader=test_dataloader,
epochs=gpc.config.NUM_EPOCHS, epochs=gpc.config.NUM_EPOCHS,
hooks=hooks, hooks=hook_list,
display_progress=True, display_progress=True,
test_interval=1) test_interval=1)

View File

@ -14,9 +14,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, CrossEntropyLoss from colossalai.nn import Accuracy, CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer from colossalai.trainer import Trainer, hooks
from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook,
LogTimingByEpochHook, LossHook, LRSchedulerHook, ThroughputHook)
from colossalai.utils import MultiTimer from colossalai.utils import MultiTimer
from model_zoo.vit import vit_small_patch16_224 from model_zoo.vit import vit_small_patch16_224
from nvidia.dali import types from nvidia.dali import types
@ -185,22 +183,22 @@ def train_imagenet():
trainer = Trainer(engine=engine, logger=logger, timer=timer) trainer = Trainer(engine=engine, logger=logger, timer=timer)
logger.info("Trainer is built", ranks=[0]) logger.info("Trainer is built", ranks=[0])
hooks = [ hook_list = [
LogMetricByEpochHook(logger=logger), hooks.LogMetricByEpochHook(logger=logger),
LogMetricByStepHook(), hooks.LogMetricByStepHook(),
# LogTimingByEpochHook(timer=timer, logger=logger), # hooks.LogTimingByEpochHook(timer=timer, logger=logger),
# LogMemoryByEpochHook(logger=logger), # hooks.LogMemoryByEpochHook(logger=logger),
AccuracyHook(accuracy_func=Accuracy()), hooks.AccuracyHook(accuracy_func=Accuracy()),
LossHook(), hooks.LossHook(),
ThroughputHook(), hooks.ThroughputHook(),
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True) hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
] ]
logger.info("Train start", ranks=[0]) logger.info("Train start", ranks=[0])
trainer.fit(train_dataloader=train_dataloader, trainer.fit(train_dataloader=train_dataloader,
test_dataloader=test_dataloader, test_dataloader=test_dataloader,
epochs=gpc.config.NUM_EPOCHS, epochs=gpc.config.NUM_EPOCHS,
hooks=hooks, hooks=hook_list,
display_progress=True, display_progress=True,
test_interval=1) test_interval=1)

View File

@ -0,0 +1 @@
from .gpt import *

284
model_zoo/gpt/gpt.py Normal file
View File

@ -0,0 +1,284 @@
import math
from typing import Callable
import torch
from colossalai import nn as col_nn
from colossalai.nn.layer.utils import CheckpointModule
from colossalai.registry import LAYERS, MODELS, LOSSES
from colossalai.utils import get_current_device
from torch import dtype, nn
__all__ = ['GPT', 'GPTLMLoss', 'gpt2_small', 'gpt2_medium', 'gpt2_large', 'gpt2_xl', 'gpt3']
@LAYERS.register_module
class GPTEmbedding(nn.Module):
def __init__(self,
embedding_dim: int,
vocab_size: int,
max_position_embeddings: int,
num_tokentypes: int = 0,
padding_idx: int = 0,
dropout: float = 0.,
dtype: dtype = None) -> None:
super().__init__()
self.word_embeddings = col_nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx, dtype=dtype)
self.position_embeddings = col_nn.Embedding(max_position_embeddings, embedding_dim, dtype=dtype)
if num_tokentypes > 0:
self.tokentype_embeddings = col_nn.Embedding(num_tokentypes, embedding_dim, dtype=dtype)
else:
self.tokentype_embeddings = None
self.dropout = col_nn.Dropout(dropout)
@property
def word_embedding_weight(self):
return self.word_embeddings.weight
def forward(self, input_ids, position_ids=None, tokentype_ids=None):
seq_length = input_ids.size(1)
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=get_current_device()).unsqueeze(0)
x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids)
if self.tokentype_embeddings is not None and tokentype_ids is not None:
x = x + self.tokentype_embeddings(tokentype_ids)
x = self.dropout(x)
return x
@LAYERS.register_module
class GPTSelfAttention(nn.Module):
def __init__(self,
dim: int,
num_heads: int,
attention_dropout: float,
dropout: float,
bias: bool = True,
dtype: dtype = None) -> None:
super().__init__()
self.attention_head_size = dim // num_heads
self.query_key_value = col_nn.Linear(dim, 3 * dim, dtype=dtype, bias=bias)
self.attention_dropout = col_nn.Dropout(attention_dropout)
self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=True)
self.dropout = col_nn.Dropout(dropout)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, attention_mask=None):
qkv = self.query_key_value(x)
all_head_size = qkv.shape[-1] // 3
num_attention_heads = all_head_size // self.attention_head_size
new_qkv_shape = qkv.shape[:-1] + \
(num_attention_heads, 3 * self.attention_head_size)
qkv = qkv.view(new_qkv_shape)
qkv = qkv.permute((0, 2, 1, 3))
q, k, v = torch.chunk(qkv, 3, dim=-1)
x = torch.matmul(q, k.transpose(-1, -2))
x = x / math.sqrt(self.attention_head_size)
# causal mask
q_len, k_len = q.size(-2), k.size(-2)
causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8,
device=get_current_device())).view(1, 1, q_len, k_len).bool()
x = torch.where(causal_mask, x, torch.tensor(-1e4, dtype=x.dtype, device=get_current_device()))
if attention_mask is not None:
x = x + attention_mask
x = self.softmax(x)
x = self.attention_dropout(x)
x = torch.matmul(x, v)
x = x.transpose(1, 2)
new_context_layer_shape = x.size()[:-2] + (all_head_size, )
x = x.reshape(new_context_layer_shape)
x = self.dense(x)
x = self.dropout(x)
return x
@LAYERS.register_module
class GPTMLP(nn.Module):
def __init__(self,
dim: int,
mlp_ratio: int,
activation: Callable,
dropout: float,
dtype: dtype = None,
bias: bool = True):
super().__init__()
self.dense_1 = col_nn.Linear(dim, mlp_ratio * dim, dtype=dtype, bias=bias)
self.activation = activation
self.dense_2 = col_nn.Linear(mlp_ratio * dim, dim, dtype=dtype, bias=bias)
self.dropout = col_nn.Dropout(dropout)
def forward(self, x):
x = self.dense_1(x)
x = self.activation(x)
x = self.dense_2(x)
x = self.dropout(x)
return x
@LAYERS.register_module
class GPTBlock(CheckpointModule):
def __init__(self,
dim: int,
num_heads: int,
mlp_ratio: int,
activation: Callable,
attention_dropout: float = 0.,
dropout: float = 0.,
dtype: dtype = None,
bias: bool = True,
checkpoint: bool = False):
super().__init__()
self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
self.attn = GPTSelfAttention(dim=dim,
num_heads=num_heads,
attention_dropout=attention_dropout,
dropout=dropout,
bias=bias,
dtype=dtype)
self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
self.mlp = GPTMLP(dim=dim, mlp_ratio=mlp_ratio, activation=activation, dropout=dropout, dtype=dtype, bias=bias)
def _forward(self, x, attention_mask=None):
x = x + self.attn(self.norm1(x), attention_mask)
x = x + self.mlp(self.norm2(x))
return x, attention_mask
@LAYERS.register_module
class GPTLMHead(nn.Module):
def __init__(self,
dim: int,
vocab_size: int,
word_embeeding_weight: nn.Parameter = None,
bias: bool = False,
dtype: dtype = None) -> None:
super().__init__()
self.dense = col_nn.Classifier(dim, vocab_size, word_embeeding_weight, bias=bias, dtype=dtype)
def forward(self, x):
x = self.dense(x)
return x
@LOSSES.register_module
class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss = col_nn.CrossEntropyLoss()
def forward(self, logits, labels):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
@MODELS.register_module
class GPT(nn.Module):
def __init__(self,
vocab_size: int = 50304,
max_position_embeddings: int = 1024,
dim: int = 768,
num_heads: int = 12,
depth: int = 12,
mlp_ratio: int = 4,
dropout: float = 0.1,
embedding_dropout: float = 0.1,
attention_dropout: float = 0.1,
layernorm_epsilon: float = 1e-5,
activation: Callable = nn.functional.gelu,
checkpoint: bool = False,
dtype: dtype = None,
bias: bool = True,
padding_idx: int = 0) -> None:
super().__init__()
self.dtype = dtype
self.embed = GPTEmbedding(embedding_dim=dim,
vocab_size=vocab_size,
max_position_embeddings=max_position_embeddings,
padding_idx=padding_idx,
dropout=embedding_dropout,
dtype=dtype)
self.blocks = nn.ModuleList([
GPTBlock(
dim=dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
activation=activation,
attention_dropout=attention_dropout,
dropout=dropout,
dtype=dtype,
bias=bias,
checkpoint=checkpoint,
) for _ in range(depth)
])
self.norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
self.head = GPTLMHead(dim=dim,
vocab_size=vocab_size,
word_embeeding_weight=self.embed.word_embedding_weight,
bias=bias,
dtype=dtype)
def forward(self, input_ids, attention_mask=None):
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# Adapted from huggingface
if attention_mask is not None:
batch_size = input_ids.shape[0]
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
x = self.embed(input_ids)
for block in self.blocks:
x, attention_mask = block(x, attention_mask)
x = self.head(self.norm(x))
return x
def _create_gpt_model(**model_kwargs):
model = GPT(**model_kwargs)
return model
@MODELS.register_module
def gpt2_small(**kwargs):
model_kwargs = dict(dim=768, depth=12, num_heads=12, **kwargs)
return _create_gpt_model(**model_kwargs)
@MODELS.register_module
def gpt2_medium(**kwargs):
model_kwargs = dict(dim=1024, depth=24, num_heads=16, **kwargs)
return _create_gpt_model(**model_kwargs)
@MODELS.register_module
def gpt2_large(**kwargs):
model_kwargs = dict(dim=1280, depth=36, num_heads=20, **kwargs)
return _create_gpt_model(**model_kwargs)
@MODELS.register_module
def gpt2_xl(**kwargs):
model_kwargs = dict(dim=1600, depth=48, num_heads=25, **kwargs)
return _create_gpt_model(**model_kwargs)
@MODELS.register_module
def gpt3(**kwargs):
model_kwargs = dict(dim=12288, max_position_embeddings=2048, depth=96, num_heads=96, **kwargs)
return _create_gpt_model(**model_kwargs)

View File

@ -89,7 +89,7 @@ class ViTEmbedding(nn.Module):
@LAYERS.register_module @LAYERS.register_module
class ViTSelfAttention(CheckpointModule): class ViTSelfAttention(nn.Module):
def __init__(self, def __init__(self,
dim: int, dim: int,
num_heads: int, num_heads: int,
@ -97,9 +97,8 @@ class ViTSelfAttention(CheckpointModule):
dropout: float, dropout: float,
bias: bool = True, bias: bool = True,
dtype: dtype = None, dtype: dtype = None,
checkpoint: bool = False,
init_method: str = 'torch'): init_method: str = 'torch'):
super().__init__(checkpoint) super().__init__()
self.attention_head_size = dim // num_heads self.attention_head_size = dim // num_heads
self.query_key_value = col_nn.Linear(dim, self.query_key_value = col_nn.Linear(dim,
3 * dim, 3 * dim,
@ -111,7 +110,7 @@ class ViTSelfAttention(CheckpointModule):
self.dropout = col_nn.Dropout(dropout) self.dropout = col_nn.Dropout(dropout)
self.softmax = nn.Softmax(dim=-1) self.softmax = nn.Softmax(dim=-1)
def _forward(self, x): def forward(self, x):
qkv = self.query_key_value(x) qkv = self.query_key_value(x)
all_head_size = qkv.shape[-1] // 3 all_head_size = qkv.shape[-1] // 3
num_attention_heads = all_head_size // self.attention_head_size num_attention_heads = all_head_size // self.attention_head_size
@ -138,7 +137,7 @@ class ViTSelfAttention(CheckpointModule):
@LAYERS.register_module @LAYERS.register_module
class ViTMLP(CheckpointModule): class ViTMLP(nn.Module):
def __init__(self, def __init__(self,
dim: int, dim: int,
mlp_ratio: int, mlp_ratio: int,
@ -146,9 +145,8 @@ class ViTMLP(CheckpointModule):
dropout: float, dropout: float,
dtype: dtype = None, dtype: dtype = None,
bias: bool = True, bias: bool = True,
checkpoint: bool = False,
init_method: str = 'torch'): init_method: str = 'torch'):
super().__init__(checkpoint) super().__init__()
self.dense_1 = col_nn.Linear(dim, self.dense_1 = col_nn.Linear(dim,
mlp_ratio * dim, mlp_ratio * dim,
dtype=dtype, dtype=dtype,
@ -163,7 +161,7 @@ class ViTMLP(CheckpointModule):
**_init_rules[init_method]['transformer']) **_init_rules[init_method]['transformer'])
self.dropout_2 = col_nn.Dropout(dropout) self.dropout_2 = col_nn.Dropout(dropout)
def _forward(self, x): def forward(self, x):
x = self.dense_1(x) x = self.dense_1(x)
x = self.activation(x) x = self.activation(x)
x = self.dropout_1(x) x = self.dropout_1(x)
@ -192,7 +190,7 @@ class ViTHead(nn.Module):
self.representation = None self.representation = None
representation_size = dim representation_size = dim
self.linear = col_nn.Classifier(representation_size, self.dense = col_nn.Classifier(representation_size,
num_classes, num_classes,
dtype=dtype, dtype=dtype,
bias=bias, bias=bias,
@ -202,12 +200,12 @@ class ViTHead(nn.Module):
x = x[:, 0] x = x[:, 0]
if self.representation is not None: if self.representation is not None:
x = self.representation(x) x = self.representation(x)
x = self.linear(x) x = self.dense(x)
return x return x
@LAYERS.register_module @LAYERS.register_module
class ViTBlock(nn.Module): class ViTBlock(CheckpointModule):
def __init__(self, def __init__(self,
dim: int, dim: int,
num_heads: int, num_heads: int,
@ -216,32 +214,31 @@ class ViTBlock(nn.Module):
attention_dropout: float = 0., attention_dropout: float = 0.,
dropout: float = 0., dropout: float = 0.,
drop_path: float = 0., drop_path: float = 0.,
layernorm_epsilon: float = 1e-6,
dtype: dtype = None, dtype: dtype = None,
bias: bool = True, bias: bool = True,
checkpoint: bool = False, checkpoint: bool = False,
init_method: str = 'torch'): init_method: str = 'torch'):
super().__init__() super().__init__(checkpoint)
self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype) self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
self.attn = ViTSelfAttention(dim=dim, self.attn = ViTSelfAttention(dim=dim,
num_heads=num_heads, num_heads=num_heads,
attention_dropout=attention_dropout, attention_dropout=attention_dropout,
dropout=dropout, dropout=dropout,
bias=bias, bias=bias,
dtype=dtype, dtype=dtype,
checkpoint=checkpoint,
init_method=init_method) init_method=init_method)
self.drop_path = col_nn.DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = col_nn.DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype) self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
self.mlp = ViTMLP(dim=dim, self.mlp = ViTMLP(dim=dim,
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
activation=activation, activation=activation,
dropout=dropout, dropout=dropout,
dtype=dtype, dtype=dtype,
bias=bias, bias=bias,
checkpoint=checkpoint,
init_method=init_method) init_method=init_method)
def forward(self, x): def _forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x))) x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x))) x = x + self.drop_path(self.mlp(self.norm2(x)))
return x return x
@ -261,6 +258,7 @@ class VisionTransformer(nn.Module):
attention_dropout: float = 0., attention_dropout: float = 0.,
dropout: float = 0.1, dropout: float = 0.1,
drop_path: float = 0., drop_path: float = 0.,
layernorm_epsilon: float = 1e-6,
activation: Callable = nn.functional.gelu, activation: Callable = nn.functional.gelu,
representation_size: int = None, representation_size: int = None,
dtype: dtype = None, dtype: dtype = None,
@ -295,7 +293,7 @@ class VisionTransformer(nn.Module):
) for i in range(depth) ) for i in range(depth)
] ]
norm = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype) norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
head = ViTHead(dim=dim, head = ViTHead(dim=dim,
num_classes=num_classes, num_classes=num_classes,