[shardformer] add Dropout layer support different dropout pattern (#3856)

* add dropout layer, add dropout test

* modify seed manager as context manager

* add a copy of col_nn.layer

* add dist_crossentropy loss; separate module test

* polish the code

* fix dist crossentropy loss
pull/3943/head
FoolPlayer 2023-06-01 16:21:02 +08:00 committed by FrankLeeeee
parent 997544c1f9
commit 21a3915c98
14 changed files with 1413 additions and 41 deletions


@ -73,7 +73,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
total_input = input
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])


@ -469,8 +469,7 @@ class Linear1D_Col(ParallelLayer):
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size)
self.out_features_per_partition = out_features
self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size)
# Parameters.
# Initialize weight.
@ -613,8 +612,7 @@ class Linear1D_Row(ParallelLayer):
raise ValueError('cannot skip bias addition if bias is None')
# Divide the weight matrix along the last dimension.
# self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size)
self.input_size_per_partition = in_features
self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
# Parameters.
# Initialize weight.
@ -886,8 +884,7 @@ class VocabParallelEmbedding1D(ParallelLayer):
tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
# self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
self.num_embeddings_per_partition = num_embeddings
self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition


@ -257,3 +257,22 @@ CustomPolicy(Policy):
- CLASS `Slicer`:
This class is used to slice tensors according to the policy.
3. DistCrossEntropy Loss
- Overview
To reduce the communication volume, the cross-entropy is calculated before the all-gather, following [Megatron-LM](https://github.com/NVIDIA/Megatron-LM); this reduces the communicated size from [batch_size * seq_length * vocab_size] to [batch_size * seq_length]. The original loss function is:
$$ loss = -\log(\frac{\exp(x[class])}{\sum_i\exp(x[i])})$$
which can also be written as:
$$ loss = \log(\sum_i\exp(x[i])) - x[class]$$
- Steps
- First, get the maximum logit across all devices and subtract it from every logit, so every value is at most zero and the exponentials cannot overflow
- Build a mask for the target positions whose logits are not stored on the local device
- Calculate the loss according to the second formula; a sketch of the computation follows this list
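
The arithmetic of these steps can be checked with a minimal single-process sketch (no tensor parallelism or all-reduce involved; the tensor names below are illustrative and not part of the implementation):

```python
import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size = 2, 4, 8
logits = torch.randn(batch_size, seq_len, vocab_size)
target = torch.randint(vocab_size, (batch_size, seq_len))

# Step 1: subtract the per-position max so exp() cannot overflow.
logits = logits - logits.max(dim=-1, keepdim=True).values

# Second formula: loss = log(sum_i exp(x[i])) - x[class]
sum_exp = torch.exp(logits).sum(dim=-1)
x_class = logits.gather(-1, target.unsqueeze(-1)).squeeze(-1)
loss = (torch.log(sum_exp) - x_class).mean()

# Matches PyTorch's reference cross entropy.
ref = F.cross_entropy(logits.view(-1, vocab_size), target.view(-1))
assert torch.allclose(loss, ref, atol=1e-6)
```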


@ -0,0 +1,97 @@
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
try:
import fused_mix_prec_layer_norm_cuda
except:
fused_mix_prec_layer_norm_cuda = None
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
r"""Layernorm
Args:
input: input matrix.
weight: weight matrix.
bias: bias matrix.
normalized_shape: input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps: a value added to the denominator for numerical stability
"""
@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_,
bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias \
= fused_mix_prec_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None
class LinearWithAsyncCommunication(torch.autograd.Function):
"""
Linear layer execution with asynchronous communication in backprop.
"""
@staticmethod
def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.parallel_mode = parallel_mode
ctx.async_grad_allreduce = async_grad_allreduce
output = torch.matmul(input_, weight.t())
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
total_input = input
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None
def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)


@ -0,0 +1,105 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
class DistCrossEntropy(Function):
r"""
Overwrite the forward and backward functions to calculate the cross-entropy loss before the gather
Args:
Function (:class:`torch.autograd.Function`): default
"""
@staticmethod
def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor):
r"""
Calculate the cross-entropy loss before the gather. The original loss function is:
loss = -log(exp(x[class]) / sum(exp(x[i])))
and can be rewritten as:
loss = log(sum(exp(x[i]))) - x[class]
To avoid `nan` from log(sum(exp(x[i]))), we subtract the max of x[i] first
Args:
vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
[batch_size, seq_len, vocab_size]
target (:class:`torch.Tensor`): The target labels, shape is
[batch_size, seq_len]
Returns:
:class:`torch.Tensor`: The cross entropy loss
"""
# get the max
logits_max = torch.max(vocab_logits, dim=-1)[0]
dist.all_reduce(logits_max, op=dist.ReduceOp.MAX)
# subtract the max so the sum of exps does not overflow and the log does not become nan
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
# mask out the targets that are not in the local device's vocabulary partition
partition_vocab_size = vocab_logits.size()[-1]
rank = dist.get_rank()
world_size = dist.get_world_size()
global_vocab_size = partition_vocab_size * world_size
# targets in [down, up) => False; targets on other devices (and the -100 padding index) => True
delta = (global_vocab_size + world_size - 1) // world_size
down_shreshold = rank * delta
up_shreshold = down_shreshold + delta
mask = (target < down_shreshold) | (target >= up_shreshold)
masked_target = target.clone() - down_shreshold
masked_target[mask] = 0
# reshape the logits and the target
# reshape vocab_logits to [batch_size * seq_len, vocab_size]
# reshape the target to [batch_size * seq_len]
logits_2d = vocab_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
# extract x[class] and set the positions owned by other devices to zero
pred_logits_1d = logits_2d[torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device),
masked_target_1d]
pred_logits_1d = pred_logits_1d.clone().contiguous()
pred_logits = pred_logits_1d.view_as(target)
pred_logits[mask] = 0.0
# all-reduce to collect x[class] from every device
dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM)
exp_logits = vocab_logits
torch.exp(vocab_logits, out=exp_logits)
sum_exp_logits = torch.sum(exp_logits, dim=-1)
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM)
# calculate the loss
# loss = log(sum(exp(x[i]))) - x[class]
loss = torch.log(sum_exp_logits) - pred_logits
loss = torch.sum(loss).div_(loss.numel())
# calculate the softmax
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, mask, masked_target_1d)
return loss
@staticmethod
def backward(ctx, grad_output):
# retrieve the saved tensors
exp_logits, mask, masked_target_1d = ctx.saved_tensors
# use exp logits as the input grad
grad_logits = exp_logits
partion_vocab_size = grad_logits.shape[-1]
grad_logits_2d = grad_logits.view(-1, partion_vocab_size)
update = 1.0 - mask.view(-1).float()
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
grad_logits.mul_(grad_output.unsqueeze(dim=-1))
return grad_logits, None, None
def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels)


@ -0,0 +1,58 @@
import os
import time
from contextlib import contextmanager
import torch
import torch.nn as nn
class SeedManager:
"""
This class manages the CUDA RNG state so that dropout layers can draw from a dedicated random seed.
"""
def __init__(self):
original_state = torch.cuda.get_rng_state()
seed = int(f"{int(time.time())}{os.environ['RANK']}")
torch.cuda.manual_seed(int(seed))
self.dropout_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(original_state)
def set_mode(self, rng_state):
torch.cuda.set_rng_state(rng_state)
def get_current_mode(self):
current_state = torch.cuda.get_rng_state()
return current_state
@contextmanager
def dropout_mode(self):
"""
This context manager switches to the dropout RNG state and restores the original state on exit.
Usage:
::
>>> with _seed_manager.dropout_mode():
>>> input = super().forward(input)
"""
try:
current_mode = self.get_current_mode()
yield self.set_mode(self.dropout_state)
finally:
self.dropout_state = self.get_current_mode()
self.set_mode(current_mode)
_seed_manager = SeedManager()
class Dropout1D(nn.Dropout):
def __init__(self, p=0.5, inplace=False):
super().__init__(p, inplace)
def forward(self, input):
with _seed_manager.dropout_mode():
input = super().forward(input)
return input
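
The RNG-state swap that `SeedManager.dropout_mode` performs can be sketched on CPU with torch's generic RNG API (the real class uses the CUDA RNG state and a rank-dependent seed; everything below is illustrative only):

```python
import torch
import torch.nn.functional as F

# Create a dedicated "dropout" RNG state without disturbing the main stream.
main_state = torch.get_rng_state()
torch.manual_seed(1234)                    # hypothetical dedicated dropout seed
dropout_state = torch.get_rng_state()
torch.set_rng_state(main_state)

def run_in_dropout_stream(fn, *args):
    """Run fn under the dropout RNG stream, then restore the main stream."""
    global dropout_state
    saved = torch.get_rng_state()
    torch.set_rng_state(dropout_state)
    try:
        return fn(*args)
    finally:
        dropout_state = torch.get_rng_state()   # remember where the dropout stream left off
        torch.set_rng_state(saved)              # main stream continues as if nothing happened

x = torch.ones(6)
print(run_in_dropout_stream(F.dropout, x, 0.5))  # mask drawn from the dropout stream
print(torch.randn(2))                            # main stream unaffected by the call above
```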

File diff suppressed because it is too large


@ -6,6 +6,8 @@ from torch.nn import CrossEntropyLoss
from transformers import BertForMaskedLM
from transformers.models.bert.modeling_bert import MaskedLMOutput
from ..layer.dist_crossentropy import applyDistCrossEntropy
class BertForMaskedLM_(BertForMaskedLM):
@ -47,11 +49,11 @@ class BertForMaskedLM_(BertForMaskedLM):
masked_lm_loss = None
# if input_ids is not None:
# masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size)
if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
masked_lm_loss = applyDistCrossEntropy(prediction_scores, labels)
# if labels is not None:
# loss_fct = CrossEntropyLoss() # -100 index = padding token
# masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (prediction_scores,) + outputs[2:]


@ -7,8 +7,6 @@ import torch
import torch.nn as nn
from transformers import AutoConfig
import colossalai.nn as col_nn
@dataclass
class Argument:


@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Tuple, Type
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
import colossalai.nn as col_nn
import colossalai.shardformer.layer.layers as col_nn
from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
@ -142,7 +142,7 @@ class BertPolicy(Policy):
weight="decoder.weight",
bias="decoder.bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True,
# gather_output=True,
)
]


@ -94,10 +94,7 @@ class Slicer():
Returns:
:class:`torch.Tensor`: The sliced tensor
"""
delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
down_idx = self.shardconfig.rank * delta
up_idx = down_idx + delta
return tensor[down_idx:up_idx].contiguous()
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
def slice_col(
self,
@ -113,10 +110,7 @@ class Slicer():
:class:`torch.Tensor`: The sliced tensor
"""
delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
down_idx = self.shardconfig.rank * delta
up_idx = down_idx + delta
return tensor[down_idx:up_idx, :].contiguous()
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
def slice_row(
self,
@ -131,7 +125,4 @@ class Slicer():
Returns:
:class:`torch.Tensor`: The sliced tensor
"""
delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
down_idx = self.shardconfig.rank * delta
up_idx = down_idx + delta
return tensor[:, down_idx:up_idx].contiguous()
return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
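
A quick check of what the new `torch.Tensor.chunk` slicing path returns (shapes here are illustrative; when the dimension is not divisible by the world size, the last rank receives the smaller shard):

```python
import torch

tensor = torch.arange(10).reshape(5, 2)   # 5 rows split across 2 ranks
world_size, rank = 2, 1
shard = tensor.chunk(world_size, dim=0)[rank].contiguous()
print(shard.shape)  # torch.Size([2, 2]) -- rank 0 gets 3 rows, rank 1 gets the remaining 2
```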


@ -0,0 +1,50 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import colossalai
from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy
from colossalai.shardformer.layer.dropout import Dropout1D
def get_args():
parser = colossalai.get_default_parser()
parser.add_argument("--module", type=str, default='distloss')
return parser.parse_args()
def test_dist_crossentropy():
pred = torch.randn(2, 4, 8, requires_grad=True)
labels = torch.randint(8, (1, 4)).repeat(2, 1)
pred_ = pred.view(-1, 8)
labels_ = labels.view(-1)
loss = F.cross_entropy(pred_, labels_)
loss.backward()
print(f"normal loss:{loss}")
pred = pred.chunk(int(os.environ['WORLD_SIZE']), -1)[int(os.environ['RANK'])]
loss = applyDistCrossEntropy(pred.to('cuda'), labels.to('cuda'))
loss.backward()
print(f"dist loss:{loss}")
def test_dropout():
input = torch.randn(5, 4).to("cuda")
m = Dropout1D(p=0.2).to("cuda")
for i in range(2):
print(f"Output: {m(input)}")
print(torch.randn(1))
if __name__ == '__main__':
args = get_args()
colossalai.launch_from_torch(config={})
if args.module == 'distloss':
test_dist_crossentropy()
elif args.module == 'dropout':
test_dropout()
else:
print("not implemented yet")


@ -1,11 +1,12 @@
import os
import random
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, get_scheduler
import colossalai
from colossalai.shardformer.shard import ShardConfig, shard_model
@ -18,6 +19,7 @@ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def get_args():
parser = colossalai.get_default_parser()
parser.add_argument("--mode", type=str, default='inference')
parser.add_argument("--save_model", action='store_true')
return parser.parse_args()
@ -30,36 +32,40 @@ def load_data():
# tokenized_datasets=tokenized_datasets.rename_column("label","labels")
tokenized_datasets.set_format("torch")
train_dataset = tokenized_datasets["train"].select(range(500))
test_dataset = tokenized_datasets["test"].select(range(100))
train_dataset = tokenized_datasets["train"]
test_dataset = tokenized_datasets["test"]
datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt")
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=datacollector)
eval_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True, collate_fn=datacollector)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=datacollector)
eval_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=datacollector)
return train_dataloader, eval_dataloader
def inference(model: nn.Module):
print(model)
def inference(model: nn.Module, args):
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
token = "Hello, my dog is cute"
inputs = tokenizer(token, return_tensors="pt")
inputs.to("cuda")
model.eval()
model.to("cuda")
outputs = model(**inputs)
print(outputs)
def train(model: nn.Module, num_epoch: int = 2):
def train(model: nn.Module, args, num_epoch: int = 3):
train_dataloader, eval_dataloader = load_data()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
progress_bar = tqdm(range((num_epoch) * len(train_dataloader)))
criterion = nn.CrossEntropyLoss()
num_training = num_epoch * len(train_dataloader)
progress_bar = tqdm(range(num_training))
lr_scheduler = get_scheduler(name="linear",
optimizer=optimizer,
num_warmup_steps=0,
num_training_steps=num_training)
best_test_loss = float("inf")
model.to("cuda")
model.train()
for epoch in range(num_epoch):
progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}")
for batch in train_dataloader:
optimizer.zero_grad()
batch = {k: v.to('cuda') for k, v in batch.items()}
@ -67,6 +73,7 @@ def train(model: nn.Module, num_epoch: int = 2):
loss = outputs.loss
loss.backward()
optimizer.step()
lr_scheduler.step()
progress_bar.update(1)
train_loss = loss
@ -75,16 +82,20 @@ def train(model: nn.Module, num_epoch: int = 2):
batch = {k: v.to('cuda') for k, v in batch.items()}
outputs = model(**batch)
# loss = outputs.loss
assert not torch.isnan(outputs.loss), f"{batch}"
loss += outputs.loss.item()
# loss = criterion(outputs.logits, batch["input_ids"])
test_loss = loss / len(eval_dataloader)
print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}")
if args.save_model and test_loss < best_test_loss:
best_test_loss = test_loss
torch.save(model.state_dict(), "./checkpoints/best_model.pth")
if __name__ == "__main__":
args = get_args()
colossalai.launch_from_torch(config=args.config)
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
colossalai.launch_from_torch(config=args.config)
shard_config = ShardConfig(
rank=int(str(get_current_device()).split(':')[-1]),
world_size=int(os.environ['WORLD_SIZE']),
@ -92,6 +103,8 @@ if __name__ == "__main__":
sharded_model = shard_model(model, shard_config)
if args.mode == "train":
train(sharded_model)
train(sharded_model, args)
elif args.mode == "inference":
inference(sharded_model)
inference(sharded_model, args)
else:
raise NotImplementedError