mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] add Dropout layer support different dropout pattern (#3856)
* add dropout layer, add dropout test * modify seed manager as context manager * add a copy of col_nn.layer * add dist_crossentropy loss; separate module test * polish the code * fix dist crossentropy losspull/3943/head
parent
997544c1f9
commit
21a3915c98
|
@ -73,7 +73,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
|
|||
total_input = input
|
||||
grad_input = grad_output.matmul(weight)
|
||||
|
||||
grad_output = grad_output.contiguous()
|
||||
# Convert the tensor shapes to 2D for execution compatibility
|
||||
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
|
||||
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
|
||||
|
|
|
@ -469,8 +469,7 @@ class Linear1D_Col(ParallelLayer):
|
|||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
|
||||
# self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size)
|
||||
self.out_features_per_partition = out_features
|
||||
self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size)
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
|
@ -613,8 +612,7 @@ class Linear1D_Row(ParallelLayer):
|
|||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
# self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size)
|
||||
self.input_size_per_partition = in_features
|
||||
self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
|
@ -886,8 +884,7 @@ class VocabParallelEmbedding1D(ParallelLayer):
|
|||
|
||||
tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
|
||||
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
# self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
|
||||
self.num_embeddings_per_partition = num_embeddings
|
||||
self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
|
||||
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
|
||||
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
|
||||
|
||||
|
|
|
@ -257,3 +257,22 @@ CustomPolicy(Policy):
|
|||
- CLASS `Slicer`:
|
||||
|
||||
This class is used to slice tensor according to policy.
|
||||
|
||||
|
||||
3. DistCrossEntropy Loss
|
||||
- Overview
|
||||
|
||||
In order to reduce the communication size, caculate the crossentropy before all gather, refer to [Megatron-LM](https://github.com/NVIDIA/Megatron-LM), reduce the communication size from [batch_size * seq_length * vocab_size] to [batch_size * seq_length]. The origin loss function is:
|
||||
$$ loss = -\log(\frac{\exp(x[class])}{\sum_i\exp(x[i])})$$
|
||||
|
||||
alse can be represented as:
|
||||
|
||||
$$ loss = \log(\sum_i\exp(x[i])) - x[class]$$
|
||||
|
||||
- Step
|
||||
|
||||
- First get the maximum logits across all the devices, make all the logist minus the maximun value to scale the value less than zero to avoid the value of exp being too large
|
||||
|
||||
- Get a mask to mask the logits not in the local device
|
||||
|
||||
- Caculate the loss according to the second formula
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
try:
|
||||
import fused_mix_prec_layer_norm_cuda
|
||||
except:
|
||||
fused_mix_prec_layer_norm_cuda = None
|
||||
|
||||
|
||||
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
|
||||
r"""Layernorm
|
||||
|
||||
Args:
|
||||
input: input matrix.
|
||||
weight: weight matrix.
|
||||
bias: bias matrix.
|
||||
normalized_shape: input shape from an expected input of size.
|
||||
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
|
||||
If a single integer is used, it is treated as a singleton list, and this module will
|
||||
normalize over the last dimension which is expected to be of that specific size.
|
||||
eps: a value added to the denominator for numerical stability
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight, bias, normalized_shape, eps):
|
||||
ctx.normalized_shape = normalized_shape
|
||||
ctx.eps = eps
|
||||
input_ = input.contiguous()
|
||||
weight_ = weight.contiguous()
|
||||
bias_ = bias.contiguous()
|
||||
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_,
|
||||
bias_, ctx.eps)
|
||||
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
|
||||
grad_input = grad_weight = grad_bias = None
|
||||
grad_input, grad_weight, grad_bias \
|
||||
= fused_mix_prec_layer_norm_cuda.backward_affine(
|
||||
grad_output.contiguous(), mean, invvar,
|
||||
input_, ctx.normalized_shape,
|
||||
weight_, bias_, ctx.eps)
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None
|
||||
|
||||
|
||||
class LinearWithAsyncCommunication(torch.autograd.Function):
|
||||
"""
|
||||
Linear layer execution with asynchronous communication in backprop.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
|
||||
ctx.save_for_backward(input_, weight)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.parallel_mode = parallel_mode
|
||||
ctx.async_grad_allreduce = async_grad_allreduce
|
||||
|
||||
output = torch.matmul(input_, weight.t())
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight = ctx.saved_tensors
|
||||
use_bias = ctx.use_bias
|
||||
|
||||
total_input = input
|
||||
grad_input = grad_output.matmul(weight)
|
||||
grad_output = grad_output.contiguous()
|
||||
# Convert the tensor shapes to 2D for execution compatibility
|
||||
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
|
||||
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
|
||||
|
||||
if ctx.async_grad_allreduce:
|
||||
# Asynchronous all-reduce
|
||||
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
|
||||
# Delay the start of weight gradient computation shortly (3us) to have
|
||||
# all-reduce scheduled first and have GPU resources allocated
|
||||
_ = torch.empty(1, device=grad_output.device) + 1
|
||||
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
if ctx.async_grad_allreduce:
|
||||
handle.wait()
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
|
||||
|
||||
def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
|
||||
return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)
|
|
@ -0,0 +1,105 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
class DistCrossEntropy(Function):
|
||||
r"""
|
||||
Overwrite the forward and backward function to calculate the cross entropy loss before gather
|
||||
|
||||
Args:
|
||||
Function (:class:`torch.autograd.Function`): default
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor):
|
||||
r"""
|
||||
Calculate the cross entropy loss before gather, the origin loss function is as follows:
|
||||
loss = -log(exp(x[class])/sum(exp(x[i]))
|
||||
and can be rewrite as:
|
||||
loss = log(sum(exp(x[i])) - x[class]
|
||||
|
||||
To avoid the `nan` of log(sim(exp(x[i]))), we minus the max of x[i]
|
||||
|
||||
Args:
|
||||
vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
|
||||
[batch_size, seq_len, vocab_size]
|
||||
labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is
|
||||
[batch_size, seq_len]
|
||||
|
||||
Returns:
|
||||
:class:`torch.Tensor`: The cross entropy loss
|
||||
"""
|
||||
# get the max
|
||||
logits_max = torch.max(vocab_logits, dim=-1)[0]
|
||||
dist.all_reduce(logits_max, op=dist.ReduceOp.MAX)
|
||||
|
||||
# minus the max to avoid the result of sum of exp is too large and the log is nan
|
||||
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
|
||||
|
||||
# mask the target in the local device
|
||||
partition_vocab_size = vocab_logits.size()[-1]
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
global_vocab_size = partition_vocab_size * world_size
|
||||
|
||||
# [down, up) => false, other device and -100 => true
|
||||
delta = (global_vocab_size + world_size - 1) // world_size
|
||||
down_shreshold = rank * delta
|
||||
up_shreshold = down_shreshold + delta
|
||||
mask = (target < down_shreshold) | (target >= up_shreshold)
|
||||
masked_target = target.clone() - down_shreshold
|
||||
masked_target[mask] = 0
|
||||
|
||||
# reshape the logist and target
|
||||
# reshape the vocab_logits to [bath_size * seq_len, vocab_size]
|
||||
# reshape the labels to [bath_size * seq_len]
|
||||
logits_2d = vocab_logits.view(-1, partition_vocab_size)
|
||||
masked_target_1d = masked_target.view(-1)
|
||||
|
||||
# extract the x[class] and set the x[other device] to zero
|
||||
pred_logits_1d = logits_2d[torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device),
|
||||
masked_target_1d]
|
||||
pred_logits_1d = pred_logits_1d.clone().contiguous()
|
||||
pred_logits = pred_logits_1d.view_as(target)
|
||||
pred_logits[mask] = 0.0
|
||||
|
||||
# allreduce the get all x(i,y)
|
||||
dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM)
|
||||
exp_logits = vocab_logits
|
||||
torch.exp(vocab_logits, out=exp_logits)
|
||||
sum_exp_logits = torch.sum(exp_logits, dim=-1)
|
||||
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM)
|
||||
|
||||
# calculate the loss
|
||||
# loss = log(sum(exp(x[i]))) - x[class]
|
||||
loss = torch.log(sum_exp_logits) - pred_logits
|
||||
loss = torch.sum(loss).div_(loss.numel())
|
||||
|
||||
# caculate the softmax
|
||||
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
|
||||
ctx.save_for_backward(exp_logits, mask, masked_target_1d)
|
||||
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
# retrieve the saved tensors
|
||||
exp_logits, mask, masked_target_1d = ctx.saved_tensors
|
||||
|
||||
# use exp logits as the input grad
|
||||
grad_logits = exp_logits
|
||||
partion_vocab_size = grad_logits.shape[-1]
|
||||
grad_logits_2d = grad_logits.view(-1, partion_vocab_size)
|
||||
|
||||
update = 1.0 - mask.view(-1).float()
|
||||
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
|
||||
|
||||
grad_logits.mul_(grad_output.unsqueeze(dim=-1))
|
||||
return grad_logits, None, None
|
||||
|
||||
|
||||
def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
||||
return DistCrossEntropy.apply(vocab_logits, labels)
|
|
@ -0,0 +1,58 @@
|
|||
import os
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class SeedManager:
|
||||
"""
|
||||
This class is a random state manager to change random state for different random seed.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
original_state = torch.cuda.get_rng_state()
|
||||
seed = int(f"{int(time.time())}{os.environ['RANK']}")
|
||||
torch.cuda.manual_seed(int(seed))
|
||||
self.dropout_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(original_state)
|
||||
|
||||
def set_mode(self, rng_state):
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
|
||||
def get_current_mode(self):
|
||||
current_state = torch.cuda.get_rng_state()
|
||||
return current_state
|
||||
|
||||
@contextmanager
|
||||
def dropout_mode(self):
|
||||
"""
|
||||
This is a context manager to change the dropout state and recover the original state.
|
||||
|
||||
Usage:
|
||||
::
|
||||
>>> with _seed_manager.dropout_mode():
|
||||
>>> input = super().forward(input)
|
||||
"""
|
||||
try:
|
||||
current_mode = self.get_current_mode()
|
||||
yield self.set_mode(self.dropout_state)
|
||||
finally:
|
||||
self.dropout_state = self.get_current_mode()
|
||||
self.set_mode(current_mode)
|
||||
|
||||
|
||||
_seed_manager = SeedManager()
|
||||
|
||||
|
||||
class Dropout1D(nn.Dropout):
|
||||
|
||||
def __init__(self, p=0.5, inplace=False):
|
||||
super().__init__(p, inplace)
|
||||
|
||||
def forward(self, input):
|
||||
with _seed_manager.dropout_mode():
|
||||
input = super().forward(input)
|
||||
return input
|
File diff suppressed because it is too large
Load Diff
|
@ -6,6 +6,8 @@ from torch.nn import CrossEntropyLoss
|
|||
from transformers import BertForMaskedLM
|
||||
from transformers.models.bert.modeling_bert import MaskedLMOutput
|
||||
|
||||
from ..layer.dist_crossentropy import applyDistCrossEntropy
|
||||
|
||||
|
||||
class BertForMaskedLM_(BertForMaskedLM):
|
||||
|
||||
|
@ -47,11 +49,11 @@ class BertForMaskedLM_(BertForMaskedLM):
|
|||
|
||||
masked_lm_loss = None
|
||||
|
||||
# if input_ids is not None:
|
||||
# masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size)
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
masked_lm_loss = applyDistCrossEntropy(prediction_scores, labels)
|
||||
# if labels is not None:
|
||||
# loss_fct = CrossEntropyLoss() # -100 index = padding token
|
||||
# masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
|
|
|
@ -7,8 +7,6 @@ import torch
|
|||
import torch.nn as nn
|
||||
from transformers import AutoConfig
|
||||
|
||||
import colossalai.nn as col_nn
|
||||
|
||||
|
||||
@dataclass
|
||||
class Argument:
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Tuple, Type
|
|||
import torch.nn as nn
|
||||
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
|
||||
|
||||
import colossalai.nn as col_nn
|
||||
import colossalai.shardformer.layer.layers as col_nn
|
||||
|
||||
from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
|
||||
|
||||
|
@ -142,7 +142,7 @@ class BertPolicy(Policy):
|
|||
weight="decoder.weight",
|
||||
bias="decoder.bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
gather_output=True,
|
||||
# gather_output=True,
|
||||
)
|
||||
]
|
||||
|
||||
|
|
|
@ -94,10 +94,7 @@ class Slicer():
|
|||
Returns:
|
||||
:class:`torch.Tensor`: The sliced tensor
|
||||
"""
|
||||
delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
|
||||
down_idx = self.shardconfig.rank * delta
|
||||
up_idx = down_idx + delta
|
||||
return tensor[down_idx:up_idx].contiguous()
|
||||
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
|
||||
|
||||
def slice_col(
|
||||
self,
|
||||
|
@ -113,10 +110,7 @@ class Slicer():
|
|||
:class:`torch.Tensor`: The sliced tensor
|
||||
|
||||
"""
|
||||
delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
|
||||
down_idx = self.shardconfig.rank * delta
|
||||
up_idx = down_idx + delta
|
||||
return tensor[down_idx:up_idx, :].contiguous()
|
||||
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
|
||||
|
||||
def slice_row(
|
||||
self,
|
||||
|
@ -131,7 +125,4 @@ class Slicer():
|
|||
Returns:
|
||||
:class:`torch.Tensor`: The sliced tensor
|
||||
"""
|
||||
delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
|
||||
down_idx = self.shardconfig.rank * delta
|
||||
up_idx = down_idx + delta
|
||||
return tensor[:, down_idx:up_idx].contiguous()
|
||||
return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import colossalai
|
||||
from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy
|
||||
from colossalai.shardformer.layer.dropout import Dropout1D
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = colossalai.get_default_parser()
|
||||
parser.add_argument("--module", type=str, default='distloss')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_dist_crossentropy():
|
||||
pred = torch.randn(2, 4, 8, requires_grad=True)
|
||||
labels = torch.randint(8, (1, 4)).repeat(2, 1)
|
||||
|
||||
pred_ = pred.view(-1, 8)
|
||||
labels_ = labels.view(-1)
|
||||
loss = F.cross_entropy(pred_, labels_)
|
||||
loss.backward()
|
||||
print(f"normal loss:{loss}")
|
||||
|
||||
pred = pred.chunk(int(os.environ['WORLD_SIZE']), -1)[int(os.environ['RANK'])]
|
||||
loss = applyDistCrossEntropy(pred.to('cuda'), labels.to('cuda'))
|
||||
loss.backward()
|
||||
print(f"dist loss:{loss}")
|
||||
|
||||
|
||||
def test_dropout():
|
||||
input = torch.randn(5, 4).to("cuda")
|
||||
m = Dropout1D(p=0.2).to("cuda")
|
||||
for i in range(2):
|
||||
print(f"Output: {m(input)}")
|
||||
print(torch.randn(1))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = get_args()
|
||||
colossalai.launch_from_torch(config={})
|
||||
if args.module == 'distloss':
|
||||
test_dist_crossentropy()
|
||||
elif args.module == 'dropout':
|
||||
test_dropout()
|
||||
else:
|
||||
print("not implemented yet")
|
|
@ -1,11 +1,12 @@
|
|||
import os
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling
|
||||
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, get_scheduler
|
||||
|
||||
import colossalai
|
||||
from colossalai.shardformer.shard import ShardConfig, shard_model
|
||||
|
@ -18,6 +19,7 @@ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
|||
def get_args():
|
||||
parser = colossalai.get_default_parser()
|
||||
parser.add_argument("--mode", type=str, default='inference')
|
||||
parser.add_argument("--save_model", action='store_true')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -30,36 +32,40 @@ def load_data():
|
|||
# tokenized_datasets=tokenized_datasets.rename_column("label","labels")
|
||||
tokenized_datasets.set_format("torch")
|
||||
|
||||
train_dataset = tokenized_datasets["train"].select(range(500))
|
||||
test_dataset = tokenized_datasets["test"].select(range(100))
|
||||
train_dataset = tokenized_datasets["train"]
|
||||
test_dataset = tokenized_datasets["test"]
|
||||
|
||||
datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt")
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=datacollector)
|
||||
eval_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True, collate_fn=datacollector)
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=datacollector)
|
||||
eval_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=datacollector)
|
||||
return train_dataloader, eval_dataloader
|
||||
|
||||
|
||||
def inference(model: nn.Module):
|
||||
print(model)
|
||||
def inference(model: nn.Module, args):
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
token = "Hello, my dog is cute"
|
||||
inputs = tokenizer(token, return_tensors="pt")
|
||||
inputs.to("cuda")
|
||||
model.eval()
|
||||
model.to("cuda")
|
||||
outputs = model(**inputs)
|
||||
print(outputs)
|
||||
|
||||
|
||||
def train(model: nn.Module, num_epoch: int = 2):
|
||||
def train(model: nn.Module, args, num_epoch: int = 3):
|
||||
train_dataloader, eval_dataloader = load_data()
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
|
||||
progress_bar = tqdm(range((num_epoch) * len(train_dataloader)))
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
num_training = num_epoch * len(train_dataloader)
|
||||
progress_bar = tqdm(range(num_training))
|
||||
lr_scheduler = get_scheduler(name="linear",
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=0,
|
||||
num_training_steps=num_training)
|
||||
best_test_loss = float("inf")
|
||||
model.to("cuda")
|
||||
model.train()
|
||||
for epoch in range(num_epoch):
|
||||
progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}")
|
||||
|
||||
for batch in train_dataloader:
|
||||
optimizer.zero_grad()
|
||||
batch = {k: v.to('cuda') for k, v in batch.items()}
|
||||
|
@ -67,6 +73,7 @@ def train(model: nn.Module, num_epoch: int = 2):
|
|||
loss = outputs.loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
progress_bar.update(1)
|
||||
train_loss = loss
|
||||
|
||||
|
@ -75,16 +82,20 @@ def train(model: nn.Module, num_epoch: int = 2):
|
|||
batch = {k: v.to('cuda') for k, v in batch.items()}
|
||||
outputs = model(**batch)
|
||||
# loss = outputs.loss
|
||||
assert not torch.isnan(outputs.loss), f"{batch}"
|
||||
loss += outputs.loss.item()
|
||||
# loss = criterion(outputs.logits, batch["input_ids"])
|
||||
test_loss = loss / len(eval_dataloader)
|
||||
print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}")
|
||||
if args.save_model and test_loss < best_test_loss:
|
||||
best_test_loss = test_loss
|
||||
torch.save(model.state_dict(), "./checkpoints/best_model.pth")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
colossalai.launch_from_torch(config=args.config)
|
||||
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
|
||||
colossalai.launch_from_torch(config=args.config)
|
||||
shard_config = ShardConfig(
|
||||
rank=int(str(get_current_device()).split(':')[-1]),
|
||||
world_size=int(os.environ['WORLD_SIZE']),
|
||||
|
@ -92,6 +103,8 @@ if __name__ == "__main__":
|
|||
sharded_model = shard_model(model, shard_config)
|
||||
|
||||
if args.mode == "train":
|
||||
train(sharded_model)
|
||||
train(sharded_model, args)
|
||||
elif args.mode == "inference":
|
||||
inference(sharded_model)
|
||||
inference(sharded_model, args)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
|
Loading…
Reference in New Issue