[shardformer] refactored the user api (#3828)

* [shardformer] refactored the user api

* polish code
pull/4157/head
Frank Lee 2 years ago
parent 235792f170
commit 4972e1f40e

@@ -18,7 +18,7 @@
The sample API usage is given below:
``` python
from colossalai.shardformer.shard.shardmodel import ShardModel
from colossalai.shardformer import shard_model
from transformers import BertForMaskedLM
# create huggingface model as normal
@@ -26,11 +26,11 @@ model = BertForMaskedLM.from_pretrained("bert-base-uncased")
# parallelize the huggingface model
# auto policy:
shardmodel = ShardModel(model).model
sharded_model = shard_model(model)
# custom policy:
from xxx import <POLICYCLASS>
shardmodel = ShardModel(model, <POLICYCLASS>).model
sharded_model = shard_model(model, policy=<POLICYCLASS>)
# do anything as normal
...
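For reference, a minimal end-to-end sketch of the refactored functional API. The tokenizer call, device placement, and forward pass are illustrative additions, not part of this commit, and it assumes the distributed environment has already been initialized (e.g. via `colossalai.launch` or `torchrun`):
``` python
import torch
from transformers import AutoTokenizer, BertForMaskedLM

from colossalai.shardformer import shard_model

# create the huggingface model as usual
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# shard it in place with the automatically selected policy
sharded_model = shard_model(model)

# run a forward pass on the sharded model
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt")
with torch.no_grad():
    logits = sharded_model(**inputs).logits
print(logits.shape)  # (batch, seq_len, vocab_size)
```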

@@ -0,0 +1,5 @@
from .shard_config import ShardConfig
from .sharder import ModelSharder, shard_model
from .slicer import Slicer
__all__ = ['ShardConfig', 'ModelSharder', 'shard_model', 'Slicer']
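A trivial usage note on the new package layout: the names above are re-exported at the `colossalai.shardformer.shard` level, while the README imports `shard_model` from the parent `colossalai.shardformer` package, which presumably re-exports it as well.
``` python
# the four names exported by the new colossalai/shardformer/shard/__init__.py
from colossalai.shardformer.shard import ShardConfig, ModelSharder, shard_model, Slicer

# the shorter path used in the README, presumably re-exported by the parent package
from colossalai.shardformer import shard_model
```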

@@ -1,5 +1,7 @@
from dataclasses import dataclass
__all__ = ['ShardConfig']
@dataclass
class ShardConfig:
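The body of `ShardConfig` is truncated in this hunk. Based only on how the test script later in this diff appears to construct it (with `rank=...` and `world_size=...` keyword arguments), a hypothetical minimal shape would be the sketch below; the real dataclass may define more fields.
``` python
from dataclasses import dataclass


@dataclass
class ShardConfig:
    # hypothetical fields, inferred from the ShardConfig(rank=..., world_size=...)
    # usage further down; not the actual definition from this commit
    rank: int
    world_size: int
```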

@@ -1,20 +1,15 @@
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List
import torch
import torch.nn as nn
import colossalai.nn as col_nn
from colossalai.logging import get_dist_logger
from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Layer, Policy
from ..policies.basepolicy import Policy
from ..utils.utils import getattr_, hasattr_, setattr_
from .shardconfig import ShardConfig
from .shard_config import ShardConfig
from .slicer import Slicer
logger = get_dist_logger()
__all__ = ['ModelSharder', 'shard_model']
class ModelSharder(object):
@@ -245,3 +240,17 @@ class ModelSharder(object):
param = nn.Parameter(param)
setattr_(model, k, param)
setattr_(model, v, param)
def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Policy = None):
r"""
The function is used to shard the PyTorch model.
Args:
model (`torch.nn.Model`): the origin huggingface model
shard_config (`ShardConfig`): the config for distribute information
policy (`Policy`): the custom policy for sharding
"""
sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy)
sharder.shard()
return model
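For completeness, a sketch of the lower-level, object-oriented path that `shard_model` wraps, using only the constructor call visible above; passing `policy=None` presumably falls back to the auto policy.
``` python
from transformers import BertForMaskedLM

from colossalai.shardformer.shard import ModelSharder, shard_model

model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# what shard_model does internally: build a ModelSharder and shard the model in place
sharder = ModelSharder(model=model, shard_config=None, policy=None)
sharder.shard()

# the one-line equivalent
sharded_model = shard_model(BertForMaskedLM.from_pretrained("bert-base-uncased"))
```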

@@ -1,60 +0,0 @@
import os
from contextlib import suppress
from dataclasses import dataclass
import torch
import torch.distributed as dist
import torch.nn as nn
import transformers
from colossalai.tensor.d_tensor.layout import Layout
from ..policies.basepolicy import Policy
from .shardconfig import ShardConfig
from .sharder import ModelSharder
class ShardModel(object):
r"""
The class for sharding the huggingface model; ``self.model`` is the sharded model.
Just create a new ShardModel object to shard a huggingface model.
Args:
model (:class:`torch.nn.Module`): the original huggingface model
shard_config (:class:`ShardConfig`): the config describing the distributed setting
custom_policy (:class:`Policy`): the custom policy for sharding
"""
def __init__(
self,
model: nn.Module,
shard_config: ShardConfig = None, # TODO
custom_policy: Policy = None,
) -> None:
self.model = model
self.shard_config = shard_config
self.policy = custom_policy
# self.layout=, # TODO
sharder = ModelSharder(
model=self.model,
policy=self.policy,
shard_config=self.shard_config,
)
sharder.shard()
def set_environ(self) -> None:
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU"
os.environ["MASTER_ADDR"] = str(self.dist_config.master_addr)
os.environ["MASTER_PORT"] = str(self.dist_config.master_port)
os.environ["WORLD_SIZE"] = str(self.dist_config.num_gpus)
os.environ["RANK"] = str(self.dist_config.rank)
os.environ["LOCAL_RANK"] = str(self.dist_config.rank)
if not dist.is_initialized():
dist.init_process_group(backend=self.dist_config.backend)
torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))
def back_to_org() -> None:
pass

@@ -1,12 +1,7 @@
import os
from dataclasses import dataclass
from typing import Dict, Tuple
import torch
import torch.distributed as dist
from ..policies.basepolicy import Col_Layer, Layer, Row_Layer
from .shardconfig import ShardConfig
from .shard_config import ShardConfig
dim_mapping = {Col_Layer: 1, Row_Layer: 0}
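The `dim_mapping` above records which dimension of a parameter each layer type is sliced along: `Col_Layer` along dim 1, `Row_Layer` along dim 0. A standalone illustration in plain PyTorch (not ColossalAI code):
``` python
import torch

world_size, rank = 4, 1
weight = torch.randn(1024, 1024)

# Col_Layer -> slice along dim 1, Row_Layer -> slice along dim 0
col_shard = torch.chunk(weight, world_size, dim=1)[rank]
row_shard = torch.chunk(weight, world_size, dim=0)[rank]

print(col_shard.shape)  # torch.Size([1024, 256])
print(row_shard.shape)  # torch.Size([256, 1024])
```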

@@ -1,5 +1,3 @@
import argparse
import inspect
import os
import torch
@@ -7,12 +5,10 @@ import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling
import colossalai
from colossalai.logging import get_dist_logger
from colossalai.shardformer.shard.shardconfig import ShardConfig
from colossalai.shardformer.shard.shardmodel import ShardModel
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.utils import get_current_device, print_rank_0
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@@ -93,8 +89,9 @@ if __name__ == "__main__":
rank=int(str(get_current_device()).split(':')[-1]),
world_size=int(os.environ['WORLD_SIZE']),
)
shardmodel = ShardModel(model, shard_config)
sharded_model = shard_model(model, shard_config)
if args.mode == "train":
train(shardmodel.model)
train(sharded_model)
elif args.mode == "inference":
inference(shardmodel.model)
inference(sharded_model)
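Condensed, the updated test entry point reduces to the pattern below. This sketch assumes the keyword arguments shown in the hunk belong to the `ShardConfig` constructor and that the launcher has already set `WORLD_SIZE` and initialized the process group; argument parsing and the train/inference loops are omitted.
``` python
import os

from transformers import BertForMaskedLM

from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.utils import get_current_device

model = BertForMaskedLM.from_pretrained("bert-base-uncased")

shard_config = ShardConfig(
    rank=int(str(get_current_device()).split(':')[-1]),
    world_size=int(os.environ['WORLD_SIZE']),
)

# the sharded model is a plain nn.Module; train or run inference on it as usual
sharded_model = shard_model(model, shard_config)
```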
