diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index f76cbac8d..10fd1809b 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -18,7 +18,7 @@ The sample API usage is given below:
 
 ``` python
-from colossalai.shardformer.shard.shardmodel import ShardModel
+from colossalai.shardformer import shard_model
 from transformers import BertForMaskedLM
 
 # create huggingface model as normal
@@ -26,11 +26,11 @@ model = BertForMaskedLM.from_pretrained("bert-base-uncased")
 
 # make the huggingface model paralleled to ShardModel
 # auto policy:
-shardmodel = ShardModel(model).model
+sharded_model = shard_model(model)
 
 # custom policy:
 from xxx import
-shardmodel = ShardModel(model, ).model
+sharded_model = shard_model(model, )
 
 # do angthing as normal
 ...
diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py
index e69de29bb..d5f70163a 100644
--- a/colossalai/shardformer/shard/__init__.py
+++ b/colossalai/shardformer/shard/__init__.py
@@ -0,0 +1,5 @@
+from .shard_config import ShardConfig
+from .sharder import ModelSharder, shard_model
+from .slicer import Slicer
+
+__all__ = ['ShardConfig', 'ModelSharder', 'shard_model', 'Slicer']
diff --git a/colossalai/shardformer/shard/shardconfig.py b/colossalai/shardformer/shard/shard_config.py
similarity index 93%
rename from colossalai/shardformer/shard/shardconfig.py
rename to colossalai/shardformer/shard/shard_config.py
index c6a2513a6..4cf9162b9 100644
--- a/colossalai/shardformer/shard/shardconfig.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -1,5 +1,7 @@
 from dataclasses import dataclass
 
+__all__ = ['ShardConfig']
+
 
 @dataclass
 class ShardConfig:
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index 2f6bb4265..221866188 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -1,20 +1,15 @@
-import os
-from dataclasses import dataclass
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, List
 
 import torch
 import torch.nn as nn
 
-import colossalai.nn as col_nn
-from colossalai.logging import get_dist_logger
-
 from ..policies.autopolicy import get_autopolicy
-from ..policies.basepolicy import Layer, Policy
+from ..policies.basepolicy import Policy
 from ..utils.utils import getattr_, hasattr_, setattr_
-from .shardconfig import ShardConfig
+from .shard_config import ShardConfig
 from .slicer import Slicer
 
-logger = get_dist_logger()
+__all__ = ['ModelSharder', 'shard_model']
 
 
 class ModelSharder(object):
@@ -245,3 +240,17 @@ class ModelSharder(object):
             param = nn.Parameter(param)
             setattr_(model, k, param)
             setattr_(model, v, param)
+
+
+def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Policy = None):
+    r"""
+    The function is used to shard the PyTorch model.
+
+    Args:
+        model (`torch.nn.Model`): the origin huggingface model
+        shard_config (`ShardConfig`): the config for distribute information
+        policy (`Policy`): the custom policy for sharding
+    """
+    sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy)
+    sharder.shard()
+    return model
diff --git a/colossalai/shardformer/shard/shardmodel.py b/colossalai/shardformer/shard/shardmodel.py
deleted file mode 100644
index 7e7d1576a..000000000
--- a/colossalai/shardformer/shard/shardmodel.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import os
-from contextlib import suppress
-from dataclasses import dataclass
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-import transformers
-
-from colossalai.tensor.d_tensor.layout import Layout
-
-from ..policies.basepolicy import Policy
-from .shardconfig import ShardConfig
-from .sharder import ModelSharder
-
-
-class ShardModel(object):
-    r"""
-    The class for sharding the huggingface model, ''self.model'' is the sharded model
-    Just creat a new ShardModel object to shard huggingface model
-
-    Args:
-        model (:class:`torch.nn.Model`): the origin huggingface model
-        dist_config (:class:`ShardConfig`): the config for distribute information
-        custom_policy (:class:`Policy`): the custom policy for sharding
-    """
-
-    def __init__(
-        self,
-        model: nn.Module,
-        shard_config: ShardConfig = None,    # TODO
-        custom_policy: Policy = None,
-    ) -> None:
-        self.model = model
-        self.shard_config = shard_config
-        self.policy = custom_policy
-        # self.layout=, # TODO
-
-        sharder = ModelSharder(
-            model=self.model,
-            policy=self.policy,
-            shard_config=self.shard_config,
-        )
-        sharder.shard()
-
-    def set_environ(self) -> None:
-        os.environ["TOKENIZERS_PARALLELISM"] = "true"
-        os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU"
-        os.environ["MASTER_ADDR"] = str(self.dist_config.master_addr)
-        os.environ["MASTER_PORT"] = str(self.dist_config.master_port)
-        os.environ["WORLD_SIZE"] = str(self.dist_config.num_gpus)
-        os.environ["RANK"] = str(self.dist_config.rank)
-        os.environ["LOCAL_RANK"] = str(self.dist_config.rank)
-        if not dist.is_initialized():
-            dist.init_process_group(backend=self.dist_config.backend)
-
-        torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))
-
-    def back_to_org() -> None:
-        pass
diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py
index 096f5db95..957ce1f85 100644
--- a/colossalai/shardformer/shard/slicer.py
+++ b/colossalai/shardformer/shard/slicer.py
@@ -1,12 +1,7 @@
-import os
-from dataclasses import dataclass
-from typing import Dict, Tuple
-
 import torch
-import torch.distributed as dist
 
 from ..policies.basepolicy import Col_Layer, Layer, Row_Layer
-from .shardconfig import ShardConfig
+from .shard_config import ShardConfig
 
 dim_mapping = {Col_Layer: 1, Row_Layer: 0}
diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py
index 0cdc6ef38..202208123 100644
--- a/colossalai/shardformer/test/test.py
+++ b/colossalai/shardformer/test/test.py
@@ -1,5 +1,3 @@
-import argparse
-import inspect
 import os
 
 import torch
@@ -7,12 +5,10 @@ import torch.nn as nn
 from datasets import load_dataset
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
-from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments
+from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling
 
 import colossalai
-from colossalai.logging import get_dist_logger
-from colossalai.shardformer.shard.shardconfig import ShardConfig
-from colossalai.shardformer.shard.shardmodel import ShardModel
+from colossalai.shardformer.shard import ShardConfig, shard_model
 from colossalai.utils import get_current_device, print_rank_0
 
 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@@ -93,8 +89,9 @@ if __name__ == "__main__":
         rank=int(str(get_current_device()).split(':')[-1]),
         world_size=int(os.environ['WORLD_SIZE']),
     )
-    shardmodel = ShardModel(model, shard_config)
+    sharded_model = shard_model(model, shard_config)
+
     if args.mode == "train":
-        train(shardmodel.model)
+        train(sharded_model)
     elif args.mode == "inference":
-        inference(shardmodel.model)
+        inference(sharded_model)
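For reference, here is a minimal usage sketch of the functional API this patch introduces (`shard_model` plus `ShardConfig` from `colossalai.shardformer.shard`). It assumes BERT is covered by the auto policy and that `ShardConfig` accepts the `rank`/`world_size` fields used in `test/test.py` above; the custom-policy import shown in the comment is hypothetical.

```python
import os

from transformers import BertForMaskedLM

from colossalai.shardformer.shard import ShardConfig, shard_model

# Create the HuggingFace model as usual.
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# Describe the distributed layout; the field names mirror test/test.py above
# and are an assumption about ShardConfig's full signature.
shard_config = ShardConfig(
    rank=int(os.environ.get("RANK", "0")),
    world_size=int(os.environ.get("WORLD_SIZE", "1")),
)

# shard_model() replaces the old ShardModel(model).model pattern and returns
# the sharded module directly; a custom Policy can be passed as the third argument.
sharded_model = shard_model(model, shard_config)

# With a custom policy (hypothetical policy class):
# sharded_model = shard_model(model, shard_config, policy=MyCustomPolicy())
```

Returning the module directly, rather than wrapping it in a `ShardModel` object, keeps the sharded model a drop-in replacement for the original `nn.Module`, which is why `test.py` can now call `train(sharded_model)` without unwrapping.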