
[shardformer] refactored the user api (#3828)

* [shardformer] refactored the user api

* polish code
pull/3943/head
Frank Lee 2 years ago committed by FrankLeeeee
parent commit 537a52b7a2
Changed files:
  1. colossalai/shardformer/README.md (6)
  2. colossalai/shardformer/shard/__init__.py (5)
  3. colossalai/shardformer/shard/shard_config.py (2)
  4. colossalai/shardformer/shard/sharder.py (27)
  5. colossalai/shardformer/shard/shardmodel.py (60)
  6. colossalai/shardformer/shard/slicer.py (7)
  7. colossalai/shardformer/test/test.py (15)

colossalai/shardformer/README.md (6 lines changed)

@@ -18,7 +18,7 @@
 The sample API usage is given below:
 ``` python
-from colossalai.shardformer.shard.shardmodel import ShardModel
+from colossalai.shardformer import shard_model
 from transformers import BertForMaskedLM
 # create huggingface model as normal
@@ -26,11 +26,11 @@ model = BertForMaskedLM.from_pretrained("bert-base-uncased")
 # make the huggingface model paralleled to ShardModel
 # auto policy:
-shardmodel = ShardModel(model).model
+sharded_model = shard_model(model)
 # custom policy:
 from xxx import <POLICYCLASS>
-shardmodel = ShardModel(model, <POLICYCLASS>).model
+sharded_model = shard_model(model, <POLICYCLASS>)
 # do angthing as normal
 ...
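Read as plain code rather than a diff, the refactored README usage collapses the old two-step ShardModel(...).model pattern into a single call. A minimal sketch under the README's assumptions (the import path shown above, and a placeholder MyBertPolicy standing in for <POLICYCLASS>):

```python
from transformers import BertForMaskedLM

from colossalai.shardformer import shard_model

# create the huggingface model as normal
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# auto policy: shardformer selects a policy matching the model class
sharded_model = shard_model(model)

# custom policy: pass your own Policy explicitly; MyBertPolicy is a
# placeholder name, not part of the library
# sharded_model = shard_model(model, policy=MyBertPolicy)
```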

colossalai/shardformer/shard/__init__.py (5 lines changed, new file)

@@ -0,0 +1,5 @@
+from .shard_config import ShardConfig
+from .sharder import ModelSharder, shard_model
+from .slicer import Slicer
+
+__all__ = ['ShardConfig', 'ModelSharder', 'shard_model', 'Slicer']

colossalai/shardformer/shard/shardconfig.py → colossalai/shardformer/shard/shard_config.py (2 lines changed)

@@ -1,5 +1,7 @@
 from dataclasses import dataclass
+
+__all__ = ['ShardConfig']
 @dataclass
 class ShardConfig:
colossalai/shardformer/shard/sharder.py (27 lines changed)

@@ -1,20 +1,15 @@
-import os
-from dataclasses import dataclass
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, List

 import torch
 import torch.nn as nn

-import colossalai.nn as col_nn
-from colossalai.logging import get_dist_logger
 from ..policies.autopolicy import get_autopolicy
-from ..policies.basepolicy import Layer, Policy
+from ..policies.basepolicy import Policy
 from ..utils.utils import getattr_, hasattr_, setattr_
-from .shardconfig import ShardConfig
+from .shard_config import ShardConfig
 from .slicer import Slicer

-logger = get_dist_logger()
+__all__ = ['ModelSharder', 'shard_model']


 class ModelSharder(object):
@@ -245,3 +240,17 @@ class ModelSharder(object):
             param = nn.Parameter(param)
             setattr_(model, k, param)
             setattr_(model, v, param)
+
+
+def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Policy = None):
+    r"""
+    The function is used to shard the PyTorch model.
+
+    Args:
+        model (`torch.nn.Model`): the origin huggingface model
+        shard_config (`ShardConfig`): the config for distribute information
+        policy (`Policy`): the custom policy for sharding
+    """
+    sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy)
+    sharder.shard()
+    return model
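The new module-level helper is a thin wrapper: it builds a ModelSharder, shards the module in place, and returns the same object. A sketch of the two equivalent call styles, assuming the package exports from shard/__init__.py above:

```python
import torch.nn as nn

from colossalai.shardformer.shard import ModelSharder, shard_model


def shard_with_helper(model: nn.Module) -> nn.Module:
    # one-call path: default shard_config and auto policy
    sharded = shard_model(model)
    assert sharded is model    # sharded in place, same object returned
    return sharded


def shard_by_hand(model: nn.Module) -> nn.Module:
    # what shard_model does internally, per the hunk above
    sharder = ModelSharder(model=model, shard_config=None, policy=None)
    sharder.shard()
    return model
```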

colossalai/shardformer/shard/shardmodel.py (60 lines changed, file deleted)

@@ -1,60 +0,0 @@
-import os
-from contextlib import suppress
-from dataclasses import dataclass
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-import transformers
-
-from colossalai.tensor.d_tensor.layout import Layout
-
-from ..policies.basepolicy import Policy
-from .shardconfig import ShardConfig
-from .sharder import ModelSharder
-
-
-class ShardModel(object):
-    r"""
-    The class for sharding the huggingface model, ''self.model'' is the sharded model
-    Just creat a new ShardModel object to shard huggingface model
-
-    Args:
-        model (:class:`torch.nn.Model`): the origin huggingface model
-        dist_config (:class:`ShardConfig`): the config for distribute information
-        custom_policy (:class:`Policy`): the custom policy for sharding
-    """
-
-    def __init__(
-        self,
-        model: nn.Module,
-        shard_config: ShardConfig = None,    # TODO
-        custom_policy: Policy = None,
-    ) -> None:
-        self.model = model
-        self.shard_config = shard_config
-        self.policy = custom_policy
-        # self.layout=,    # TODO
-
-        sharder = ModelSharder(
-            model=self.model,
-            policy=self.policy,
-            shard_config=self.shard_config,
-        )
-        sharder.shard()
-
-    def set_environ(self) -> None:
-        os.environ["TOKENIZERS_PARALLELISM"] = "true"
-        os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU"
-        os.environ["MASTER_ADDR"] = str(self.dist_config.master_addr)
-        os.environ["MASTER_PORT"] = str(self.dist_config.master_port)
-        os.environ["WORLD_SIZE"] = str(self.dist_config.num_gpus)
-        os.environ["RANK"] = str(self.dist_config.rank)
-        os.environ["LOCAL_RANK"] = str(self.dist_config.rank)
-        if not dist.is_initialized():
-            dist.init_process_group(backend=self.dist_config.backend)
-        torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))
-
-    def back_to_org() -> None:
-        pass
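With ShardModel deleted, callers no longer unwrap the result through a .model attribute, and the custom_policy keyword becomes policy on shard_model. A migration sketch (the Policy import path is taken from the relative imports above; the wrapper function is illustrative):

```python
from torch import nn

from colossalai.shardformer.policies.basepolicy import Policy
from colossalai.shardformer.shard import shard_model


def migrate(model: nn.Module, custom_policy: Policy = None) -> nn.Module:
    # before this commit:
    #   sharded = ShardModel(model, custom_policy=custom_policy).model
    # after this commit the sharded module comes back directly:
    return shard_model(model, policy=custom_policy)
```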

colossalai/shardformer/shard/slicer.py (7 lines changed)

@@ -1,12 +1,7 @@
-import os
-from dataclasses import dataclass
-from typing import Dict, Tuple
 import torch
-import torch.distributed as dist
 from ..policies.basepolicy import Col_Layer, Layer, Row_Layer
-from .shardconfig import ShardConfig
+from .shard_config import ShardConfig
 dim_mapping = {Col_Layer: 1, Row_Layer: 0}

colossalai/shardformer/test/test.py (15 lines changed)

@@ -1,5 +1,3 @@
-import argparse
-import inspect
 import os

 import torch
@@ -7,12 +5,10 @@ import torch.nn as nn
 from datasets import load_dataset
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
-from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments
+from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling

 import colossalai
-from colossalai.logging import get_dist_logger
-from colossalai.shardformer.shard.shardconfig import ShardConfig
-from colossalai.shardformer.shard.shardmodel import ShardModel
+from colossalai.shardformer.shard import ShardConfig, shard_model
 from colossalai.utils import get_current_device, print_rank_0

 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@@ -93,8 +89,9 @@ if __name__ == "__main__":
         rank=int(str(get_current_device()).split(':')[-1]),
         world_size=int(os.environ['WORLD_SIZE']),
     )
-    shardmodel = ShardModel(model, shard_config)
+    sharded_model = shard_model(model, shard_config)
+
     if args.mode == "train":
-        train(shardmodel.model)
+        train(sharded_model)
     elif args.mode == "inference":
-        inference(shardmodel.model)
+        inference(sharded_model)
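For context on what the updated test exercises: it builds a ShardConfig from the launcher environment and hands the sharded model to its train or inference path. The sketch below mirrors only the ShardConfig(rank=..., world_size=...) construction and the shard_model call from the hunk above; the prompt, the single-process environment fallbacks, and the decoding are illustrative, not taken from the test:

```python
import os

import torch
from transformers import AutoTokenizer, BertForMaskedLM

from colossalai.shardformer.shard import ShardConfig, shard_model

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# in the real test, rank comes from the current device and world size from
# the environment set up by the distributed launcher
shard_config = ShardConfig(rank=int(os.environ.get("RANK", 0)),
                           world_size=int(os.environ.get("WORLD_SIZE", 1)))
sharded_model = shard_model(model, shard_config)

# run a single masked-LM prediction through the sharded model
inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt")
with torch.no_grad():
    logits = sharded_model(**inputs).logits

mask_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0, 1]
top_id = int(logits[0, mask_index].argmax(-1))
print(tokenizer.decode([top_id]))
```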
