mirror of https://github.com/hpcaitech/ColossalAI

[shardformer] refactored the user api (#3828)

* [shardformer] refactored the user api
* polish code

pull/3943/head
parent bc19024bf9
commit 537a52b7a2
@@ -18,7 +18,7 @@
 The sample API usage is given below:
 
 ``` python
-from colossalai.shardformer.shard.shardmodel import ShardModel
+from colossalai.shardformer import shard_model
 from transformers import BertForMaskedLM
 
@@ -26,11 +26,11 @@ model = BertForMaskedLM.from_pretrained("bert-base-uncased")
 
 # parallelize the huggingface model
 # auto policy:
-shardmodel = ShardModel(model).model
+sharded_model = shard_model(model)
 
 # custom policy:
 from xxx import <POLICYCLASS>
-shardmodel = ShardModel(model, <POLICYCLASS>).model
+sharded_model = shard_model(model, <POLICYCLASS>)
 
 # do anything as normal
 ...
 
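Taken together, the README hunks above collapse the old two-step wrapper (`ShardModel(model).model`) into a single `shard_model(model)` call. A minimal end-to-end sketch of the new usage, assuming a distributed process group has already been set up by the launcher; the model name and input sentence are illustrative:

``` python
import torch
from transformers import AutoTokenizer, BertForMaskedLM

from colossalai.shardformer import shard_model

# create the huggingface model as normal
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# shard it in place under the automatically selected policy
sharded_model = shard_model(model)

# then use it like any other nn.Module
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt")
with torch.no_grad():
    outputs = sharded_model(**inputs)
```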
@@ -0,0 +1,5 @@
+from .shard_config import ShardConfig
+from .sharder import ModelSharder, shard_model
+from .slicer import Slicer
+
+__all__ = ['ShardConfig', 'ModelSharder', 'shard_model', 'Slicer']
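The new `__init__.py` promotes the package's entry points, so downstream code can import from `colossalai.shardformer.shard` directly rather than reaching into submodules, e.g. `from colossalai.shardformer.shard import ShardConfig, shard_model` as the updated example script further down does.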
@@ -1,5 +1,7 @@
 from dataclasses import dataclass
 
+__all__ = ['ShardConfig']
+
 
 @dataclass
 class ShardConfig:
@@ -1,20 +1,15 @@
-import os
-from dataclasses import dataclass
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, List
 
 import torch
 import torch.nn as nn
 
-import colossalai.nn as col_nn
-from colossalai.logging import get_dist_logger
-
 from ..policies.autopolicy import get_autopolicy
-from ..policies.basepolicy import Layer, Policy
+from ..policies.basepolicy import Policy
 from ..utils.utils import getattr_, hasattr_, setattr_
-from .shardconfig import ShardConfig
+from .shard_config import ShardConfig
 from .slicer import Slicer
 
-logger = get_dist_logger()
+__all__ = ['ModelSharder', 'shard_model']
 
 
 class ModelSharder(object):
@@ -245,3 +240,17 @@ class ModelSharder(object):
             param = nn.Parameter(param)
             setattr_(model, k, param)
             setattr_(model, v, param)
+
+
+def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Policy = None):
+    r"""
+    The function is used to shard the PyTorch model.
+
+    Args:
+        model (`torch.nn.Model`): the origin huggingface model
+        shard_config (`ShardConfig`): the config for distribute information
+        policy (`Policy`): the custom policy for sharding
+    """
+    sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy)
+    sharder.shard()
+    return model
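A design note on the new entry point: `sharder.shard()` rewrites the module in place and `shard_model` returns that same object, so the return value and the original `model` are identical. A hedged sketch of a call site, with the `ShardConfig` fields mirrored from the example script further down (rank and world size); everything else is illustrative:

``` python
import os

from transformers import BertForMaskedLM

from colossalai.shardformer.shard import ShardConfig, shard_model

model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# rank/world_size mirror the fields used in the example script below
config = ShardConfig(
    rank=int(os.environ["RANK"]),
    world_size=int(os.environ["WORLD_SIZE"]),
)

# policy=None selects the auto policy (the README's "auto policy" case)
sharded = shard_model(model, shard_config=config, policy=None)
assert sharded is model    # sharded in place, same object returned
```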
@@ -1,60 +0,0 @@
-import os
-from contextlib import suppress
-from dataclasses import dataclass
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-import transformers
-
-from colossalai.tensor.d_tensor.layout import Layout
-
-from ..policies.basepolicy import Policy
-from .shardconfig import ShardConfig
-from .sharder import ModelSharder
-
-
-class ShardModel(object):
-    r"""
-    The class for sharding the huggingface model, ''self.model'' is the sharded model
-    Just creat a new ShardModel object to shard huggingface model
-
-    Args:
-        model (:class:`torch.nn.Model`): the origin huggingface model
-        dist_config (:class:`ShardConfig`): the config for distribute information
-        custom_policy (:class:`Policy`): the custom policy for sharding
-    """
-
-    def __init__(
-        self,
-        model: nn.Module,
-        shard_config: ShardConfig = None,    # TODO
-        custom_policy: Policy = None,
-    ) -> None:
-        self.model = model
-        self.shard_config = shard_config
-        self.policy = custom_policy
-        # self.layout=,    # TODO
-
-        sharder = ModelSharder(
-            model=self.model,
-            policy=self.policy,
-            shard_config=self.shard_config,
-        )
-        sharder.shard()
-
-    def set_environ(self) -> None:
-        os.environ["TOKENIZERS_PARALLELISM"] = "true"
-        os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU"
-        os.environ["MASTER_ADDR"] = str(self.dist_config.master_addr)
-        os.environ["MASTER_PORT"] = str(self.dist_config.master_port)
-        os.environ["WORLD_SIZE"] = str(self.dist_config.num_gpus)
-        os.environ["RANK"] = str(self.dist_config.rank)
-        os.environ["LOCAL_RANK"] = str(self.dist_config.rank)
-        if not dist.is_initialized():
-            dist.init_process_group(backend=self.dist_config.backend)
-
-        torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))
-
-    def back_to_org() -> None:
-        pass
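Note that with `ShardModel` deleted, nothing in the new path performs the setup that `set_environ` did (`MASTER_ADDR`, `WORLD_SIZE`, process-group init, device selection); that responsibility moves to the caller or launcher. A minimal sketch of the equivalent caller-side setup, assuming a launcher such as torchrun has already exported the rendezvous variables:

``` python
import os

import torch
import torch.distributed as dist

# torchrun exports MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE and
# LOCAL_RANK, which init_process_group picks up via the default
# env:// rendezvous
if not dist.is_initialized():
    dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))
```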
@@ -1,12 +1,7 @@
-import os
-from dataclasses import dataclass
-from typing import Dict, Tuple
-
 import torch
-import torch.distributed as dist
 
 from ..policies.basepolicy import Col_Layer, Layer, Row_Layer
-from .shardconfig import ShardConfig
+from .shard_config import ShardConfig
 
 dim_mapping = {Col_Layer: 1, Row_Layer: 0}
 
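The `dim_mapping` table records which weight dimension each layer kind is sliced along: column layers on dim 1, row layers on dim 0. A toy illustration of that slicing, independent of the actual `Slicer` implementation (shapes and ranks are made up):

``` python
import torch

dim_mapping = {"col": 1, "row": 0}    # mirrors {Col_Layer: 1, Row_Layer: 0}

weight = torch.randn(8, 8)
world_size, rank = 4, 1

# each rank keeps one contiguous chunk along the mapped dimension
col_shard = torch.chunk(weight, world_size, dim=dim_mapping["col"])[rank]
row_shard = torch.chunk(weight, world_size, dim=dim_mapping["row"])[rank]
print(col_shard.shape)    # torch.Size([8, 2])
print(row_shard.shape)    # torch.Size([2, 8])
```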
@@ -1,5 +1,3 @@
-import argparse
-import inspect
 import os
 
 import torch
@@ -7,12 +5,10 @@ import torch.nn as nn
 from datasets import load_dataset
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
-from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments
+from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling
 
 import colossalai
-from colossalai.logging import get_dist_logger
-from colossalai.shardformer.shard.shardconfig import ShardConfig
-from colossalai.shardformer.shard.shardmodel import ShardModel
+from colossalai.shardformer.shard import ShardConfig, shard_model
 from colossalai.utils import get_current_device, print_rank_0
 
 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@@ -93,8 +89,9 @@ if __name__ == "__main__":
         rank=int(str(get_current_device()).split(':')[-1]),
         world_size=int(os.environ['WORLD_SIZE']),
     )
-    shardmodel = ShardModel(model, shard_config)
+    sharded_model = shard_model(model, shard_config)
+
     if args.mode == "train":
-        train(shardmodel.model)
+        train(sharded_model)
     elif args.mode == "inference":
-        inference(shardmodel.model)
+        inference(sharded_model)
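Since the script derives its `ShardConfig` from `WORLD_SIZE` and the current device, it assumes a distributed launcher has populated those variables; torchrun, for instance, exports `WORLD_SIZE`, `RANK`, and `LOCAL_RANK` for each process (an invocation like `torchrun --nproc_per_node=2 <script>.py --mode train` is illustrative, as the script's filename is not shown in this diff).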