diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 93a4f1e57..222626db3 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -20,7 +20,7 @@ The sample API usage is given below: ``` python -from colossalai.shardformer import shard_model +from colossalai.shardformer import ShardConfig, shard_model from transformers import BertForMaskedLM # create huggingface model as normal @@ -28,7 +28,12 @@ model = BertForMaskedLM.from_pretrained("bert-base-uncased") # make the huggingface model paralleled to ShardModel # auto policy: -sharded_model = shard_model(model) +shardconfig = ShardConfig( + rank=rank, + world_size=world_size, + gather_output=True, +) +sharded_model = shard_model(model, config=shardconfig) # custom policy: from xxx import @@ -72,7 +77,7 @@ More details can be found in shardformer/policies/basepolicy.py ``` python from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument -CustomPolicy(Policy): +class CustomPolicy(Policy): @staticmethod def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]: r""" @@ -235,7 +240,7 @@ CustomPolicy(Policy): This class is used to specify the replacement policy for a particular layer. If `replace_layer` is None, only parameter partitioning will be performed without replacing the layer class. CLASS `Col_Layer(Layer)`: - - gather_output (bool): Whether to gather the output of the layer + - gather_output (bool): Whether the output of this layer can be gathered, like the last layer can be gathered, but most of the time, the intermediate layers of the model do not need to be gathered. This class inherited from `Layer`, representing the layer will be sliced along column. diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py index e69de29bb..50c927380 100644 --- a/colossalai/shardformer/__init__.py +++ b/colossalai/shardformer/__init__.py @@ -0,0 +1 @@ +from .shard import ShardConfig, shard_model diff --git a/colossalai/shardformer/layer/dist_crossentropy.py b/colossalai/shardformer/layer/dist_crossentropy.py index 186959467..05c04bb54 100644 --- a/colossalai/shardformer/layer/dist_crossentropy.py +++ b/colossalai/shardformer/layer/dist_crossentropy.py @@ -14,7 +14,7 @@ class DistCrossEntropy(Function): """ @staticmethod - def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor): + def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int): r""" Calculate the cross entropy loss before gather, the origin loss function is as follows: loss = -log(exp(x[class])/sum(exp(x[i])) @@ -75,8 +75,8 @@ class DistCrossEntropy(Function): # calculate the loss # loss = log(sum(exp(x[i]))) - x[class] - loss = torch.log(sum_exp_logits) - pred_logits - loss = torch.sum(loss).div_(loss.numel()) + loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits) + loss = torch.sum(loss).div_(torch.sum(loss != 0.0)) # caculate the softmax exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) @@ -101,5 +101,5 @@ class DistCrossEntropy(Function): return grad_logits, None, None -def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - return DistCrossEntropy.apply(vocab_logits, labels) +def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor: + return DistCrossEntropy.apply(vocab_logits, labels, ignore_index) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 89b32f065..5d489f419 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -141,7 +141,7 @@ class BertPolicy(Policy): weight="decoder.weight", bias="decoder.bias", replace_layer=col_nn.Linear1D_Col, - # gather_output=True, + gather_output=True, ) ] @@ -155,7 +155,8 @@ class BertForMaskedLMPolicy(BertPolicy): @staticmethod def inject_policy() -> Tuple[nn.Module, nn.Module]: - return (BertForMaskedLM, BertForMaskedLM_) + # return (BertForMaskedLM, BertForMaskedLM_) + return None class BertForSequenceClassificationPolicy(BertPolicy): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 4cf9162b9..e8d6f3408 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -5,16 +5,14 @@ __all__ = ['ShardConfig'] @dataclass class ShardConfig: - """ - The config for sharding the huggingface model for test + r""" + The config for sharding the huggingface model + + Args: + rank (int): The rank of local process + world_size (int): The world size of the distributed process + gather_output (bool): Whether to gather the output of the model of the last layer """ rank: int - fp16: bool = True - num_gpus: int = 2 world_size: int = 2 - backend = "nccl" - verbose: str = 'simple' - seed: int = None - require_grad: bool = False - master_addr: str = "127.0.0.1" - master_port: int = 29500 + gather_output: bool = True diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 1ada75e06..159bebccd 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -65,6 +65,8 @@ class ModelSharder(object): BertForMaskedLM.forward -> BertForMaskedLM_.forward """ inject_policy = self.policy.inject_policy() + if inject_policy is None: + return if inject_policy is None: return @@ -148,7 +150,7 @@ class ModelSharder(object): n_cast = policy_layer.n_cast reversed = policy_layer.reversed if policy_layer.__class__.__name__ == "Col_Layer": - gather_output = policy_layer.gather_output + gather_output = policy_layer.gather_output and self.shard_config.gather_output if weight_attr is not None: if hasattr_(org_layer, weight_attr): diff --git a/colossalai/shardformer/test/config.py b/colossalai/shardformer/test/config.py deleted file mode 100644 index 2b80d8b3c..000000000 --- a/colossalai/shardformer/test/config.py +++ /dev/null @@ -1 +0,0 @@ -parallel = dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')) diff --git a/colossalai/shardformer/test/module_test.py b/colossalai/shardformer/test/module_test.py deleted file mode 100644 index 83dc7ec6c..000000000 --- a/colossalai/shardformer/test/module_test.py +++ /dev/null @@ -1,50 +0,0 @@ -import os - -import torch -import torch.nn as nn -import torch.nn.functional as F - -import colossalai -from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy -from colossalai.shardformer.layer.dropout import Dropout1D - - -def get_args(): - parser = colossalai.get_default_parser() - parser.add_argument("--module", type=str, default='distloss') - return parser.parse_args() - - -def test_dist_crossentropy(): - pred = torch.randn(2, 4, 8, requires_grad=True) - labels = torch.randint(8, (1, 4)).repeat(2, 1) - - pred_ = pred.view(-1, 8) - labels_ = labels.view(-1) - loss = F.cross_entropy(pred_, labels_) - loss.backward() - print(f"normal loss:{loss}") - - pred = pred.chunk(int(os.environ['WORLD_SIZE']), -1)[int(os.environ['RANK'])] - loss = applyDistCrossEntropy(pred.to('cuda'), labels.to('cuda')) - loss.backward() - print(f"dist loss:{loss}") - - -def test_dropout(): - input = torch.randn(5, 4).to("cuda") - m = Dropout1D(p=0.2).to("cuda") - for i in range(2): - print(f"Output: {m(input)}") - print(torch.randn(1)) - - -if __name__ == '__main__': - args = get_args() - colossalai.launch_from_torch(config={}) - if args.module == 'distloss': - test_dist_crossentropy() - elif args.module == 'dropout': - test_dropout() - else: - print("not implemented yet") diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py deleted file mode 100644 index e2d5a94c7..000000000 --- a/colossalai/shardformer/test/test.py +++ /dev/null @@ -1,124 +0,0 @@ -import os -import random - -import torch -import torch.nn as nn -from datasets import load_dataset -from torch.utils.data import DataLoader -from tqdm.auto import tqdm -from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, GPT2LMHeadModel, get_scheduler - -import colossalai -from colossalai.shardformer.shard import ShardConfig, shard_model -from colossalai.utils import get_current_device, print_rank_0 - -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' - - -def get_args(): - parser = colossalai.get_default_parser() - parser.add_argument("--mode", type=str, default='inference') - parser.add_argument("--save_model", action='store_true') - parser.add_argument("--model", type=str, default='bert-base-uncased') - return parser.parse_args() - - -def load_data(args): - tokenizer = AutoTokenizer.from_pretrained(args.model) - if tokenizer.pad_token is None: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - # tokenizer.pad_token_id = 0 - datasets = load_dataset('wikitext', 'wikitext-2-raw-v1') - # datasets=load_dataset("yelp_review_full") - tokenized_datasets = datasets.map( - lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length"), batched=True) - tokenized_datasets = tokenized_datasets.remove_columns(["text"]) - # tokenized_datasets=tokenized_datasets.rename_column("label","labels") - tokenized_datasets.set_format("torch") - - train_dataset = tokenized_datasets["train"] - test_dataset = tokenized_datasets["test"] - - datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt") - train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=datacollector) - eval_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=datacollector) - return train_dataloader, eval_dataloader - - -def inference(model: nn.Module, args): - print(model) - # print(model.wte.weight.shape) - tokenizer = AutoTokenizer.from_pretrained(args.model) - if tokenizer.pad_token is None: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - tokenizer.pad_token_id = 0 - token = "Hello, my dog is cute" - inputs = tokenizer(token, return_tensors="pt") - inputs.to("cuda") - model.eval() - model.to("cuda") - outputs = model(**inputs) - print(outputs[0]) - - -def train(model: nn.Module, args, num_epoch: int = 3): - train_dataloader, eval_dataloader = load_data(args) - optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5) - num_training = num_epoch * len(train_dataloader) - progress_bar = tqdm(range(num_training)) - lr_scheduler = get_scheduler(name="linear", - optimizer=optimizer, - num_warmup_steps=0, - num_training_steps=num_training) - best_test_loss = float("inf") - model.to("cuda") - model.train() - for epoch in range(num_epoch): - progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}") - for batch in train_dataloader: - optimizer.zero_grad() - batch = {k: v.to('cuda') for k, v in batch.items()} - outputs = model(**batch) - loss = outputs.loss - loss.backward() - optimizer.step() - lr_scheduler.step() - progress_bar.update(1) - train_loss = loss - - loss = 0.0 - for batch in eval_dataloader: - batch = {k: v.to('cuda') for k, v in batch.items()} - outputs = model(**batch) - # loss = outputs.loss - assert not torch.isnan(outputs.loss), f"{batch}" - loss += outputs.loss.item() - # loss = criterion(outputs.logits, batch["input_ids"]) - test_loss = loss / len(eval_dataloader) - print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}") - if args.save_model and test_loss < best_test_loss: - best_test_loss = test_loss - torch.save(model.state_dict(), "./checkpoints/best_model.pth") - - -if __name__ == "__main__": - args = get_args() - colossalai.launch_from_torch(config=args.config) - if args.model == 'bert-base-uncased': - model = BertForMaskedLM.from_pretrained("bert-base-uncased") - elif args.model == 'gpt2': - model = GPT2LMHeadModel.from_pretrained("gpt2") - else: - raise AttributeError("model not supported") - shard_config = ShardConfig( - rank=int(str(get_current_device()).split(':')[-1]), - world_size=int(os.environ['WORLD_SIZE']), - ) - sharded_model = shard_model(model, shard_config) - - if args.mode == "train": - train(sharded_model, args) - elif args.mode == "inference": - inference(sharded_model, args) - else: - raise NotImplementedError diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py new file mode 100644 index 000000000..55b78d040 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -0,0 +1,103 @@ +import os +import random + +import pytest +import torch +from transformers import AutoTokenizer, BertConfig, BertForMaskedLM + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.shard import ShardConfig, shard_model +from colossalai.testing import rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) +tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + + +def build_model(rank, world_size): + config = BertConfig.from_pretrained('bert-base-uncased') + config.hidden_dropout_prob = 0 + config.attention_probs_dropout_prob = 0 + + org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config).to('cuda') + + shardconfig = ShardConfig( + rank=rank, + world_size=world_size, + gather_output=True, + ) + sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased', config=config), + shardconfig).to('cuda') + + return org_model, sharded_model + + +def check_forward(org_model, sharded_model): + input = 'Hello, my dog is cute' + tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') + + #orgin model + org_model.eval() + org_out = org_model(**tokenized_input) + + #shard model + sharded_model.eval() + shard_out = sharded_model(**tokenized_input) + + assert torch.allclose( + org_out[0], shard_out[0], + atol=1e-5), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}" + + +def check_backward(org_model, sharded_model): + # prepare input + input = 'Hello, my dog is cute' + tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') + labels = tokenized_input['input_ids'].clone() + labels[labels == tokenizer.pad_token_id] = -100 + tokenized_input['labels'] = labels + + #orgin model + org_model.train() + org_out = org_model(**tokenized_input) + org_loss = org_out.loss + org_loss.backward() + org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad + + #shard model + sharded_model.train() + shard_out = sharded_model(**tokenized_input) + shard_loss = shard_out.loss + shard_loss.backward() + shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad + + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + + +def check_bert(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + org_model, sharded_model = build_model(rank, world_size) + check_forward(org_model, sharded_model) + check_backward(org_model, sharded_model) + + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert(): + spawn(check_bert, 2) + + +if __name__ == "__main__": + test_bert() diff --git a/tests/test_shardformer/test_module/test_distcrossentropy.py b/tests/test_shardformer/test_module/test_distcrossentropy.py new file mode 100644 index 000000000..9a19ec578 --- /dev/null +++ b/tests/test_shardformer/test_module/test_distcrossentropy.py @@ -0,0 +1,42 @@ +import pytest +import torch +import torch.nn.functional as F + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) + + +def check_dist_crossentropy(rank, world_size, port, ignore_index): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl') + + # prepare data + pred = torch.randn(2, 4, 8, requires_grad=True) + labels = torch.randint(8, (2, 4)) + # set some label to -100 to test the ignore index + labels[0, -1] = ignore_index + + org_pred = pred.view(-1, 8) + org_labels = labels.view(-1) + org_loss = F.cross_entropy(org_pred, org_labels) + + dist_pred = pred.chunk(world_size, -1)[rank] + dist_loss = applyDistCrossEntropy(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index) + + assert torch.allclose(org_loss, dist_loss, + atol=1e-5), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}" + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_dist_crossentropy(): + ignore_index = -100 + spawn(check_dist_crossentropy, 2, ignore_index=ignore_index) + + +if __name__ == '__main__': + test_dist_crossentropy()