diff --git a/benchmark/cifar/train.py b/benchmark/cifar/train.py index 4ffa22968..38cd6e29b 100644 --- a/benchmark/cifar/train.py +++ b/benchmark/cifar/train.py @@ -11,12 +11,7 @@ from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.nn import Accuracy, CrossEntropyLoss from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.trainer import Trainer -from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook, - LogMetricByEpochHook, - LogMetricByStepHook, - LogTimingByEpochHook, LossHook, - LRSchedulerHook, ThroughputHook) +from colossalai.trainer import Trainer, hooks from colossalai.utils import MultiTimer, get_dataloader from model_zoo.vit import vit_lite_depth7_patch4_32 from torchvision import transforms @@ -100,22 +95,22 @@ def train_cifar(): trainer = Trainer(engine=engine, logger=logger, timer=timer) logger.info("Trainer is built", ranks=[0]) - hooks = [ - LogMetricByEpochHook(logger=logger), - LogMetricByStepHook(), - # LogTimingByEpochHook(timer=timer, logger=logger), - # LogMemoryByEpochHook(logger=logger), - AccuracyHook(accuracy_func=Accuracy()), - LossHook(), - ThroughputHook(), - LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False) + hook_list = [ + hooks.LogMetricByEpochHook(logger=logger), + hooks.LogMetricByStepHook(), + # hooks.LogTimingByEpochHook(timer=timer, logger=logger), + # hooks.LogMemoryByEpochHook(logger=logger), + hooks.AccuracyHook(accuracy_func=Accuracy()), + hooks.LossHook(), + hooks.ThroughputHook(), + hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False) ] logger.info("Train start", ranks=[0]) trainer.fit(train_dataloader=train_dataloader, test_dataloader=test_dataloader, epochs=gpc.config.NUM_EPOCHS, - hooks=hooks, + hooks=hook_list, display_progress=True, test_interval=1) diff --git a/benchmark/gpt2/configs/gpt2_1d.py b/benchmark/gpt2/configs/gpt2_1d.py new file mode 100644 index 000000000..f9a659b83 --- /dev/null +++ b/benchmark/gpt2/configs/gpt2_1d.py @@ -0,0 +1,29 @@ +from colossalai.amp import AMP_TYPE + +VOCAB_SIZE = 50304 +SEQ_LENGTH = 1024 + +TOTAL_BATCH_SIZE = 256 +LEARNING_RATE = 0.00015 +WEIGHT_DECAY = 1e-2 + +TENSOR_PARALLEL_SIZE = 2 +TENSOR_PARALLEL_MODE = '1d' + +NUM_EPOCHS = 60 +WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36) + +parallel = dict( + pipeline=1, + tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), +) + +fp16 = dict(mode=AMP_TYPE.TORCH, ) + +gradient_accumulation = 2 + +BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation + +clip_grad_norm = 1.0 + +LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" diff --git a/benchmark/gpt2/configs/gpt2_2d.py b/benchmark/gpt2/configs/gpt2_2d.py new file mode 100644 index 000000000..5abec73e6 --- /dev/null +++ b/benchmark/gpt2/configs/gpt2_2d.py @@ -0,0 +1,29 @@ +from colossalai.amp import AMP_TYPE + +VOCAB_SIZE = 50304 +SEQ_LENGTH = 1024 + +TOTAL_BATCH_SIZE = 256 +LEARNING_RATE = 0.00015 +WEIGHT_DECAY = 1e-2 + +TENSOR_PARALLEL_SIZE = 4 +TENSOR_PARALLEL_MODE = '2d' + +NUM_EPOCHS = 60 +WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36) + +parallel = dict( + pipeline=1, + tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), +) + +fp16 = dict(mode=AMP_TYPE.TORCH, ) + +gradient_accumulation = 1 + +BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation + +clip_grad_norm = 1.0 + +LOG_PATH = 
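# ---------------------------------------------------------------------------
# [editor's sketch, not part of the patch] The GPT-2 configs above derive
# BATCH_SIZE from TOTAL_BATCH_SIZE and gradient_accumulation, and train.py
# later divides BATCH_SIZE by the data-parallel size for the dataloader. A
# minimal illustration of how these knobs appear to combine into the number of
# samples seen per optimizer update; the helper name and the even-divisibility
# assumption are mine.
def effective_global_batch(total_batch_size: int, grad_accum: int, data_parallel_size: int) -> int:
    per_step = total_batch_size // grad_accum        # BATCH_SIZE in the configs
    per_rank = per_step // data_parallel_size        # dataloader batch size per data-parallel rank
    return per_rank * data_parallel_size * grad_accum

# gpt2_1d.py: 256 total, accumulate 2 micro-steps, e.g. 4 data-parallel ranks
assert effective_global_batch(256, 2, 4) == 256
# ---------------------------------------------------------------------------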
f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" diff --git a/benchmark/gpt2/configs/gpt2_2p5d.py b/benchmark/gpt2/configs/gpt2_2p5d.py new file mode 100644 index 000000000..33ea4411e --- /dev/null +++ b/benchmark/gpt2/configs/gpt2_2p5d.py @@ -0,0 +1,30 @@ +from colossalai.amp import AMP_TYPE + +VOCAB_SIZE = 50304 +SEQ_LENGTH = 1024 + +TOTAL_BATCH_SIZE = 256 +LEARNING_RATE = 0.00015 +WEIGHT_DECAY = 1e-2 + +TENSOR_PARALLEL_SIZE = 4 +DEPTH = 1 +TENSOR_PARALLEL_MODE = '2.5d' + +NUM_EPOCHS = 60 +WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36) + +parallel = dict( + pipeline=1, + tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=DEPTH), +) + +fp16 = dict(mode=AMP_TYPE.TORCH, ) + +gradient_accumulation = 1 + +BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation + +clip_grad_norm = 1.0 + +LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" diff --git a/benchmark/gpt2/configs/gpt2_3d.py b/benchmark/gpt2/configs/gpt2_3d.py new file mode 100644 index 000000000..9f8728d29 --- /dev/null +++ b/benchmark/gpt2/configs/gpt2_3d.py @@ -0,0 +1,29 @@ +from colossalai.amp import AMP_TYPE + +VOCAB_SIZE = 50304 +SEQ_LENGTH = 1024 + +TOTAL_BATCH_SIZE = 256 +LEARNING_RATE = 0.00015 +WEIGHT_DECAY = 1e-2 + +TENSOR_PARALLEL_SIZE = 8 +TENSOR_PARALLEL_MODE = '3d' + +NUM_EPOCHS = 60 +WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36) + +parallel = dict( + pipeline=1, + tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), +) + +fp16 = dict(mode=AMP_TYPE.TORCH, ) + +gradient_accumulation = 1 + +BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation + +clip_grad_norm = 1.0 + +LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" diff --git a/benchmark/gpt2/configs/gpt2_vanilla.py b/benchmark/gpt2/configs/gpt2_vanilla.py new file mode 100644 index 000000000..b450cd048 --- /dev/null +++ b/benchmark/gpt2/configs/gpt2_vanilla.py @@ -0,0 +1,29 @@ +from colossalai.amp import AMP_TYPE + +VOCAB_SIZE = 50304 +SEQ_LENGTH = 1024 + +TOTAL_BATCH_SIZE = 256 +LEARNING_RATE = 0.00015 +WEIGHT_DECAY = 1e-2 + +TENSOR_PARALLEL_SIZE = 1 +TENSOR_PARALLEL_MODE = None + +NUM_EPOCHS = 60 +WARMUP_EPOCHS = int(NUM_EPOCHS * 0.36) + +parallel = dict( + pipeline=1, + tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE), +) + +fp16 = dict(mode=AMP_TYPE.TORCH, ) + +gradient_accumulation = 1 + +BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation + +clip_grad_norm = 1.0 + +LOG_PATH = f"./gpt2_{TENSOR_PARALLEL_MODE}_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_accum{gradient_accumulation}_clip_grad{clip_grad_norm}/" diff --git a/benchmark/gpt2/data.py b/benchmark/gpt2/data.py new file mode 100644 index 000000000..d6fdfba78 --- /dev/null +++ b/benchmark/gpt2/data.py @@ -0,0 +1,37 @@ +import json +import os + +import torch +from colossalai.registry import DATASETS +from torch.utils.data import Dataset +from transformers import GPT2Tokenizer + + +@DATASETS.register_module +class WebtextDataset(Dataset): + def __init__(self, path, seq_len=1024) -> None: + super().__init__() + root = os.path.dirname(path) + encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt') + if os.path.isfile(encoded_data_cache_path): + seq_len_, data, attention_mask = torch.load(encoded_data_cache_path) + if seq_len_ == 
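# ---------------------------------------------------------------------------
# [editor's sketch, not part of the patch] The configs only set the pipeline
# and tensor sizes; to my understanding ColossalAI treats whatever remains of
# the world size as the data-parallel dimension, which is what
# gpc.data_parallel_size in train.py refers to. A rough sketch of that
# bookkeeping (the function name and the 16-GPU example are mine):
def data_parallel_size(world_size: int, tensor_size: int, pipeline_size: int = 1) -> int:
    assert world_size % (tensor_size * pipeline_size) == 0
    return world_size // (tensor_size * pipeline_size)

# e.g. gpt2_2d.py (tensor size 4, pipeline 1) on 16 GPUs -> 4 data-parallel ranks,
# so each rank's dataloader would get BATCH_SIZE // 4 samples per step.
print(data_parallel_size(world_size=16, tensor_size=4))  # 4
# ---------------------------------------------------------------------------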
seq_len: + self.data = data + self.attention_mask = attention_mask + return + raw_data = [] + with open(path) as f: + for line in f.readlines(): + raw_data.append(json.loads(line)['text']) + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.unk_token + encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt') + self.data = encoded_data['input_ids'] + self.attention_mask = encoded_data['attention_mask'] + torch.save((seq_len, self.data, self.attention_mask), encoded_data_cache_path) + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return (self.data[index], self.attention_mask[index]), self.data[index] diff --git a/benchmark/gpt2/train.py b/benchmark/gpt2/train.py new file mode 100644 index 000000000..664a5a206 --- /dev/null +++ b/benchmark/gpt2/train.py @@ -0,0 +1,105 @@ +import contextlib +import os + +import colossalai +import torch +from colossalai.core import global_context as gpc +from colossalai.engine.schedule import (InterleavedPipelineSchedule, PipelineSchedule) +from colossalai.logging import get_dist_logger +from colossalai.nn import CosineAnnealingWarmupLR +from colossalai.trainer import Trainer, hooks +from colossalai.utils import MultiTimer, get_dataloader +from colossalai.zero import zero3_model_context +from model_zoo.gpt import GPTLMLoss, gpt2_small, gpt2_medium, gpt2_large, gpt2_xl + +from data import WebtextDataset + + +def train_gpt(): + args = colossalai.get_default_parser().parse_args() + # standard launch + # colossalai.launch(config=args.config, + # rank=args.rank, + # world_size=args.world_size, + # local_rank=args.local_rank, + # host=args.host, + # port=args.port) + + # launch from torchrun + colossalai.launch_from_torch(config=args.config) + + logger = get_dist_logger() + if hasattr(gpc.config, 'LOG_PATH'): + if gpc.get_global_rank() == 0: + log_path = gpc.config.LOG_PATH + if not os.path.exists(log_path): + os.mkdir(log_path) + logger.log_to_file(log_path) + + train_dataset = WebtextDataset(os.environ['DATA'], seq_len=gpc.config.SEQ_LENGTH) + train_dataloader = get_dataloader(train_dataset, + seed=42, + batch_size=gpc.config.BATCH_SIZE // gpc.data_parallel_size, + pin_memory=True, + shuffle=True, + drop_last=True) + logger.info(f'Loaded {len(train_dataset)}/{len(train_dataloader)} samples/batches', ranks=[0]) + + # zero3 under test + # use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3 + # cm = zero3_model_context() if use_zero3 else contextlib.nullcontext() + # with cm: + # model = gpc.config.model.pop('type')(**gpc.config.model) + + model = gpt2_medium(vocab_size=gpc.config.VOCAB_SIZE, + max_position_embeddings=gpc.config.SEQ_LENGTH, + checkpoint=True) + + criterion = GPTLMLoss() + + optimizer = torch.optim.Adam(model.parameters(), lr=0.00015, weight_decay=1e-2) + + steps_per_epoch = len(train_dataloader) // gpc.config.gradient_accumulation + + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, + total_steps=gpc.config.NUM_EPOCHS * steps_per_epoch, + warmup_steps=gpc.config.WARMUP_EPOCHS * steps_per_epoch, + eta_min=1e-5) + + engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader, + lr_scheduler=lr_scheduler) + + # pipeline under test + # num_model_chunks = getattr(gpc.config.model, 'num_chunks', 1) + # if num_model_chunks > 1: + # logger.info('Build InterleavedPipelineSchedule', ranks=[0]) + # schedule = 
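# ---------------------------------------------------------------------------
# [editor's sketch, not part of the patch] data.py tokenizes the whole Webtext
# dump once and caches the resulting tensors, keyed by sequence length so a
# stale cache is ignored. The same pattern restated as a standalone function
# (file and cache names are hypothetical):
import json
import os

import torch
from transformers import GPT2Tokenizer

def load_or_tokenize(jsonl_path: str, seq_len: int = 1024, cache_path: str = 'webtext_cache.pt'):
    if os.path.isfile(cache_path):
        cached_seq_len, ids, mask = torch.load(cache_path)
        if cached_seq_len == seq_len:            # cache is only reusable for the same seq_len
            return ids, mask
    texts = [json.loads(line)['text'] for line in open(jsonl_path)]
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.unk_token    # GPT-2 ships without a dedicated pad token
    enc = tokenizer(texts, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
    torch.save((seq_len, enc['input_ids'], enc['attention_mask']), cache_path)
    return enc['input_ids'], enc['attention_mask']
# ---------------------------------------------------------------------------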
InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES, num_model_chunks) + # else: + # logger.info('Build PipelineSchedule', ranks=[0]) + # schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES) + + timer = MultiTimer() + + trainer = Trainer(engine=engine, logger=logger, timer=timer) + + hook_list = [ + hooks.LogMetricByEpochHook(logger=logger), + hooks.LogMetricByStepHook(), + hooks.LossHook(), + hooks.ThroughputHook(), + hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False), + # hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]), + # hooks.LogMemoryByEpochHook(logger), + # hooks.LogTimingByEpochHook(timer, logger), + # hooks.SaveCheckpointHook(checkpoint_dir='./ckpt') + ] + + logger.info("Training start", ranks=[0]) + trainer.fit(train_dataloader=train_dataloader, epochs=gpc.config.NUM_EPOCHS, hooks=hook_list, display_progress=True) + + +if __name__ == '__main__': + train_gpt() diff --git a/benchmark/imagenet100/train.py b/benchmark/imagenet100/train.py index 58ad3b15e..af06ec452 100644 --- a/benchmark/imagenet100/train.py +++ b/benchmark/imagenet100/train.py @@ -14,9 +14,7 @@ from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.nn import Accuracy, CrossEntropyLoss from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.trainer import Trainer -from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook, - LogTimingByEpochHook, LossHook, LRSchedulerHook, ThroughputHook) +from colossalai.trainer import Trainer, hooks from colossalai.utils import MultiTimer from model_zoo.vit import vit_small_patch16_224 from nvidia.dali import types @@ -185,22 +183,22 @@ def train_imagenet(): trainer = Trainer(engine=engine, logger=logger, timer=timer) logger.info("Trainer is built", ranks=[0]) - hooks = [ - LogMetricByEpochHook(logger=logger), - LogMetricByStepHook(), - # LogTimingByEpochHook(timer=timer, logger=logger), - # LogMemoryByEpochHook(logger=logger), - AccuracyHook(accuracy_func=Accuracy()), - LossHook(), - ThroughputHook(), - LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True) + hook_list = [ + hooks.LogMetricByEpochHook(logger=logger), + hooks.LogMetricByStepHook(), + # hooks.LogTimingByEpochHook(timer=timer, logger=logger), + # hooks.LogMemoryByEpochHook(logger=logger), + hooks.AccuracyHook(accuracy_func=Accuracy()), + hooks.LossHook(), + hooks.ThroughputHook(), + hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True) ] logger.info("Train start", ranks=[0]) trainer.fit(train_dataloader=train_dataloader, test_dataloader=test_dataloader, epochs=gpc.config.NUM_EPOCHS, - hooks=hooks, + hooks=hook_list, display_progress=True, test_interval=1) diff --git a/benchmark/imagenet1k/train.py b/benchmark/imagenet1k/train.py index d9b9ade99..4a77280df 100644 --- a/benchmark/imagenet1k/train.py +++ b/benchmark/imagenet1k/train.py @@ -14,9 +14,7 @@ from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.nn import Accuracy, CrossEntropyLoss from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.trainer import Trainer -from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook, - LogTimingByEpochHook, LossHook, LRSchedulerHook, ThroughputHook) +from colossalai.trainer import Trainer, hooks from colossalai.utils import MultiTimer from model_zoo.vit import vit_small_patch16_224 from nvidia.dali import 
types @@ -185,22 +183,22 @@ def train_imagenet(): trainer = Trainer(engine=engine, logger=logger, timer=timer) logger.info("Trainer is built", ranks=[0]) - hooks = [ - LogMetricByEpochHook(logger=logger), - LogMetricByStepHook(), - # LogTimingByEpochHook(timer=timer, logger=logger), - # LogMemoryByEpochHook(logger=logger), - AccuracyHook(accuracy_func=Accuracy()), - LossHook(), - ThroughputHook(), - LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True) + hook_list = [ + hooks.LogMetricByEpochHook(logger=logger), + hooks.LogMetricByStepHook(), + # hooks.LogTimingByEpochHook(timer=timer, logger=logger), + # hooks.LogMemoryByEpochHook(logger=logger), + hooks.AccuracyHook(accuracy_func=Accuracy()), + hooks.LossHook(), + hooks.ThroughputHook(), + hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True) ] logger.info("Train start", ranks=[0]) trainer.fit(train_dataloader=train_dataloader, test_dataloader=test_dataloader, epochs=gpc.config.NUM_EPOCHS, - hooks=hooks, + hooks=hook_list, display_progress=True, test_interval=1) diff --git a/model_zoo/gpt/__init__.py b/model_zoo/gpt/__init__.py new file mode 100644 index 000000000..5a20f0f81 --- /dev/null +++ b/model_zoo/gpt/__init__.py @@ -0,0 +1 @@ +from .gpt import * \ No newline at end of file diff --git a/model_zoo/gpt/gpt.py b/model_zoo/gpt/gpt.py new file mode 100644 index 000000000..99095f08c --- /dev/null +++ b/model_zoo/gpt/gpt.py @@ -0,0 +1,284 @@ +import math +from typing import Callable + +import torch +from colossalai import nn as col_nn +from colossalai.nn.layer.utils import CheckpointModule +from colossalai.registry import LAYERS, MODELS, LOSSES +from colossalai.utils import get_current_device +from torch import dtype, nn + +__all__ = ['GPT', 'GPTLMLoss', 'gpt2_small', 'gpt2_medium', 'gpt2_large', 'gpt2_xl', 'gpt3'] + + +@LAYERS.register_module +class GPTEmbedding(nn.Module): + def __init__(self, + embedding_dim: int, + vocab_size: int, + max_position_embeddings: int, + num_tokentypes: int = 0, + padding_idx: int = 0, + dropout: float = 0., + dtype: dtype = None) -> None: + super().__init__() + self.word_embeddings = col_nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx, dtype=dtype) + self.position_embeddings = col_nn.Embedding(max_position_embeddings, embedding_dim, dtype=dtype) + if num_tokentypes > 0: + self.tokentype_embeddings = col_nn.Embedding(num_tokentypes, embedding_dim, dtype=dtype) + else: + self.tokentype_embeddings = None + self.dropout = col_nn.Dropout(dropout) + + @property + def word_embedding_weight(self): + return self.word_embeddings.weight + + def forward(self, input_ids, position_ids=None, tokentype_ids=None): + seq_length = input_ids.size(1) + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=get_current_device()).unsqueeze(0) + x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids) + if self.tokentype_embeddings is not None and tokentype_ids is not None: + x = x + self.tokentype_embeddings(tokentype_ids) + x = self.dropout(x) + return x + + +@LAYERS.register_module +class GPTSelfAttention(nn.Module): + def __init__(self, + dim: int, + num_heads: int, + attention_dropout: float, + dropout: float, + bias: bool = True, + dtype: dtype = None) -> None: + super().__init__() + + self.attention_head_size = dim // num_heads + self.query_key_value = col_nn.Linear(dim, 3 * dim, dtype=dtype, bias=bias) + self.attention_dropout = col_nn.Dropout(attention_dropout) + self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=True) + 
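# ---------------------------------------------------------------------------
# [editor's sketch, not part of the patch] GPTEmbedding above sums learned
# token and position embeddings (plus optional token-type embeddings) and
# applies dropout. The same computation in plain PyTorch, without the col_nn
# tensor-parallel layers; class and attribute names are mine:
import torch
from torch import nn

class TinyGPTEmbedding(nn.Module):
    def __init__(self, vocab_size: int, dim: int, max_positions: int, dropout: float = 0.1):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, dim)
        self.pos = nn.Embedding(max_positions, dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # positions 0..seq_len-1, broadcast over the batch dimension
        positions = torch.arange(input_ids.size(1), device=input_ids.device).unsqueeze(0)
        return self.drop(self.tok(input_ids) + self.pos(positions))
# ---------------------------------------------------------------------------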
self.dropout = col_nn.Dropout(dropout) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, attention_mask=None): + qkv = self.query_key_value(x) + all_head_size = qkv.shape[-1] // 3 + num_attention_heads = all_head_size // self.attention_head_size + new_qkv_shape = qkv.shape[:-1] + \ + (num_attention_heads, 3 * self.attention_head_size) + qkv = qkv.view(new_qkv_shape) + qkv = qkv.permute((0, 2, 1, 3)) + q, k, v = torch.chunk(qkv, 3, dim=-1) + + x = torch.matmul(q, k.transpose(-1, -2)) + x = x / math.sqrt(self.attention_head_size) + + # causal mask + q_len, k_len = q.size(-2), k.size(-2) + causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8, + device=get_current_device())).view(1, 1, q_len, k_len).bool() + x = torch.where(causal_mask, x, torch.tensor(-1e4, dtype=x.dtype, device=get_current_device())) + + if attention_mask is not None: + x = x + attention_mask + x = self.softmax(x) + x = self.attention_dropout(x) + + x = torch.matmul(x, v) + x = x.transpose(1, 2) + new_context_layer_shape = x.size()[:-2] + (all_head_size, ) + x = x.reshape(new_context_layer_shape) + + x = self.dense(x) + x = self.dropout(x) + + return x + + +@LAYERS.register_module +class GPTMLP(nn.Module): + def __init__(self, + dim: int, + mlp_ratio: int, + activation: Callable, + dropout: float, + dtype: dtype = None, + bias: bool = True): + super().__init__() + self.dense_1 = col_nn.Linear(dim, mlp_ratio * dim, dtype=dtype, bias=bias) + self.activation = activation + self.dense_2 = col_nn.Linear(mlp_ratio * dim, dim, dtype=dtype, bias=bias) + self.dropout = col_nn.Dropout(dropout) + + def forward(self, x): + x = self.dense_1(x) + x = self.activation(x) + x = self.dense_2(x) + x = self.dropout(x) + return x + + +@LAYERS.register_module +class GPTBlock(CheckpointModule): + def __init__(self, + dim: int, + num_heads: int, + mlp_ratio: int, + activation: Callable, + attention_dropout: float = 0., + dropout: float = 0., + dtype: dtype = None, + bias: bool = True, + checkpoint: bool = False): + super().__init__() + self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype) + self.attn = GPTSelfAttention(dim=dim, + num_heads=num_heads, + attention_dropout=attention_dropout, + dropout=dropout, + bias=bias, + dtype=dtype) + self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype) + self.mlp = GPTMLP(dim=dim, mlp_ratio=mlp_ratio, activation=activation, dropout=dropout, dtype=dtype, bias=bias) + + def _forward(self, x, attention_mask=None): + x = x + self.attn(self.norm1(x), attention_mask) + x = x + self.mlp(self.norm2(x)) + return x, attention_mask + + +@LAYERS.register_module +class GPTLMHead(nn.Module): + def __init__(self, + dim: int, + vocab_size: int, + word_embeeding_weight: nn.Parameter = None, + bias: bool = False, + dtype: dtype = None) -> None: + super().__init__() + self.dense = col_nn.Classifier(dim, vocab_size, word_embeeding_weight, bias=bias, dtype=dtype) + + def forward(self, x): + x = self.dense(x) + return x + + +@LOSSES.register_module +class GPTLMLoss(nn.Module): + def __init__(self): + super().__init__() + self.loss = col_nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +@MODELS.register_module +class GPT(nn.Module): + def __init__(self, + vocab_size: int = 50304, + max_position_embeddings: int = 1024, + dim: int = 768, + 
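# ---------------------------------------------------------------------------
# [editor's sketch, not part of the patch] Two pieces of GPT-specific logic
# above are worth isolating: the lower-triangular causal mask applied to the
# attention scores, and the one-token label shift inside GPTLMLoss. A
# plain-PyTorch restatement of both (the tensor shapes in the comments are
# assumptions):
import math

import torch
import torch.nn.functional as F

def causal_scores(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # q, k: (batch, heads, seq, head_dim)
    scores = q @ k.transpose(-1, -2) / math.sqrt(q.size(-1))
    seq = scores.size(-1)
    causal = torch.tril(torch.ones(seq, seq, dtype=torch.bool, device=scores.device))
    return scores.masked_fill(~causal, -1e4)   # same -1e4 fill value as GPTSelfAttention

def lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq, vocab), labels: (batch, seq); predict token t+1 from position t
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
# ---------------------------------------------------------------------------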
num_heads: int = 12, + depth: int = 12, + mlp_ratio: int = 4, + dropout: float = 0.1, + embedding_dropout: float = 0.1, + attention_dropout: float = 0.1, + layernorm_epsilon: float = 1e-5, + activation: Callable = nn.functional.gelu, + checkpoint: bool = False, + dtype: dtype = None, + bias: bool = True, + padding_idx: int = 0) -> None: + super().__init__() + self.dtype = dtype + self.embed = GPTEmbedding(embedding_dim=dim, + vocab_size=vocab_size, + max_position_embeddings=max_position_embeddings, + padding_idx=padding_idx, + dropout=embedding_dropout, + dtype=dtype) + self.blocks = nn.ModuleList([ + GPTBlock( + dim=dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + activation=activation, + attention_dropout=attention_dropout, + dropout=dropout, + dtype=dtype, + bias=bias, + checkpoint=checkpoint, + ) for _ in range(depth) + ]) + + self.norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) + + self.head = GPTLMHead(dim=dim, + vocab_size=vocab_size, + word_embeeding_weight=self.embed.word_embedding_weight, + bias=bias, + dtype=dtype) + + def forward(self, input_ids, attention_mask=None): + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # Adapted from huggingface + if attention_mask is not None: + batch_size = input_ids.shape[0] + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + x = self.embed(input_ids) + + for block in self.blocks: + x, attention_mask = block(x, attention_mask) + + x = self.head(self.norm(x)) + + return x + + +def _create_gpt_model(**model_kwargs): + model = GPT(**model_kwargs) + return model + + +@MODELS.register_module +def gpt2_small(**kwargs): + model_kwargs = dict(dim=768, depth=12, num_heads=12, **kwargs) + return _create_gpt_model(**model_kwargs) + + +@MODELS.register_module +def gpt2_medium(**kwargs): + model_kwargs = dict(dim=1024, depth=24, num_heads=16, **kwargs) + return _create_gpt_model(**model_kwargs) + + +@MODELS.register_module +def gpt2_large(**kwargs): + model_kwargs = dict(dim=1280, depth=36, num_heads=20, **kwargs) + return _create_gpt_model(**model_kwargs) + + +@MODELS.register_module +def gpt2_xl(**kwargs): + model_kwargs = dict(dim=1600, depth=48, num_heads=25, **kwargs) + return _create_gpt_model(**model_kwargs) + + +@MODELS.register_module +def gpt3(**kwargs): + model_kwargs = dict(dim=12288, max_position_embeddings=2048, depth=96, num_heads=96, **kwargs) + return _create_gpt_model(**model_kwargs) diff --git a/model_zoo/vit/vit.py b/model_zoo/vit/vit.py index 450f334a4..9bdcbfd38 100644 --- a/model_zoo/vit/vit.py +++ b/model_zoo/vit/vit.py @@ -89,7 +89,7 @@ class ViTEmbedding(nn.Module): @LAYERS.register_module -class ViTSelfAttention(CheckpointModule): +class ViTSelfAttention(nn.Module): def __init__(self, dim: int, num_heads: int, @@ -97,9 +97,8 @@ class ViTSelfAttention(CheckpointModule): dropout: float, bias: bool = True, dtype: dtype = None, - checkpoint: bool = False, init_method: str = 'torch'): - super().__init__(checkpoint) + super().__init__() self.attention_head_size = dim // num_heads self.query_key_value = col_nn.Linear(dim, 3 * dim, @@ -111,7 +110,7 @@ class ViTSelfAttention(CheckpointModule): self.dropout = col_nn.Dropout(dropout) self.softmax = nn.Softmax(dim=-1) - def 
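# ---------------------------------------------------------------------------
# [editor's sketch, not part of the patch] GPT.forward above turns the 2D
# padding mask from the dataloader into an additive bias that broadcasts over
# heads and query positions before being added to the attention scores. The
# transformation on its own (function name is mine):
import torch

def extend_attention_mask(mask_2d: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # mask_2d: (batch, seq) with 1 for real tokens, 0 for padding
    mask = mask_2d[:, None, None, :].to(dtype)      # -> (batch, 1, 1, seq)
    return (1.0 - mask) * -10000.0                  # 0 on real tokens, -10000 on padding

bias = extend_attention_mask(torch.tensor([[1, 1, 1, 0]]))
print(bias.shape)  # torch.Size([1, 1, 1, 4])
# ---------------------------------------------------------------------------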
_forward(self, x): + def forward(self, x): qkv = self.query_key_value(x) all_head_size = qkv.shape[-1] // 3 num_attention_heads = all_head_size // self.attention_head_size @@ -138,7 +137,7 @@ class ViTSelfAttention(CheckpointModule): @LAYERS.register_module -class ViTMLP(CheckpointModule): +class ViTMLP(nn.Module): def __init__(self, dim: int, mlp_ratio: int, @@ -146,9 +145,8 @@ class ViTMLP(CheckpointModule): dropout: float, dtype: dtype = None, bias: bool = True, - checkpoint: bool = False, init_method: str = 'torch'): - super().__init__(checkpoint) + super().__init__() self.dense_1 = col_nn.Linear(dim, mlp_ratio * dim, dtype=dtype, @@ -163,7 +161,7 @@ class ViTMLP(CheckpointModule): **_init_rules[init_method]['transformer']) self.dropout_2 = col_nn.Dropout(dropout) - def _forward(self, x): + def forward(self, x): x = self.dense_1(x) x = self.activation(x) x = self.dropout_1(x) @@ -192,22 +190,22 @@ class ViTHead(nn.Module): self.representation = None representation_size = dim - self.linear = col_nn.Classifier(representation_size, - num_classes, - dtype=dtype, - bias=bias, - **_init_rules[init_method]['head']) + self.dense = col_nn.Classifier(representation_size, + num_classes, + dtype=dtype, + bias=bias, + **_init_rules[init_method]['head']) def forward(self, x): x = x[:, 0] if self.representation is not None: x = self.representation(x) - x = self.linear(x) + x = self.dense(x) return x @LAYERS.register_module -class ViTBlock(nn.Module): +class ViTBlock(CheckpointModule): def __init__(self, dim: int, num_heads: int, @@ -216,32 +214,31 @@ class ViTBlock(nn.Module): attention_dropout: float = 0., dropout: float = 0., drop_path: float = 0., + layernorm_epsilon: float = 1e-6, dtype: dtype = None, bias: bool = True, checkpoint: bool = False, init_method: str = 'torch'): - super().__init__() - self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype) + super().__init__(checkpoint) + self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) self.attn = ViTSelfAttention(dim=dim, num_heads=num_heads, attention_dropout=attention_dropout, dropout=dropout, bias=bias, dtype=dtype, - checkpoint=checkpoint, init_method=init_method) self.drop_path = col_nn.DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype) + self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) self.mlp = ViTMLP(dim=dim, mlp_ratio=mlp_ratio, activation=activation, dropout=dropout, dtype=dtype, bias=bias, - checkpoint=checkpoint, init_method=init_method) - def forward(self, x): + def _forward(self, x): x = x + self.drop_path(self.attn(self.norm1(x))) x = x + self.drop_path(self.mlp(self.norm2(x))) return x @@ -261,6 +258,7 @@ class VisionTransformer(nn.Module): attention_dropout: float = 0., dropout: float = 0.1, drop_path: float = 0., + layernorm_epsilon: float = 1e-6, activation: Callable = nn.functional.gelu, representation_size: int = None, dtype: dtype = None, @@ -295,7 +293,7 @@ class VisionTransformer(nn.Module): ) for i in range(depth) ] - norm = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype) + norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) head = ViTHead(dim=dim, num_classes=num_classes,
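# ---------------------------------------------------------------------------
# [editor's sketch, not part of the patch] The vit.py changes above move
# activation checkpointing from the individual attention/MLP layers up to the
# whole ViTBlock (now a CheckpointModule), so a block's intermediate
# activations are recomputed during backward instead of being stored. A generic
# stand-in for that idea using torch.utils.checkpoint; CheckpointedBlock is my
# name, not the ColossalAI class:
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    def __init__(self, block: nn.Module, use_checkpoint: bool = True):
        super().__init__()
        self.block = block
        self.use_checkpoint = use_checkpoint

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_checkpoint and self.training and x.requires_grad:
            return checkpoint(self.block, x)   # trade recomputation for activation memory
        return self.block(x)
# ---------------------------------------------------------------------------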