mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] refactored the user api (#3828)
* [shardformer] refactored the user api * polish code (pull/4157/head)
parent
235792f170
commit
4972e1f40e
@ -0,0 +1,5 @@
|
||||
from .shard_config import ShardConfig
|
||||
from .sharder import ModelSharder, shard_model
|
||||
from .slicer import Slicer
|
||||
|
||||
__all__ = ['ShardConfig', 'ModelSharder', 'shard_model', 'Slicer']
|
@ -1,5 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
__all__ = ['ShardConfig']
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShardConfig:
|
@ -1,60 +0,0 @@
|
||||
import os
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
|
||||
from ..policies.basepolicy import Policy
|
||||
from .shardconfig import ShardConfig
|
||||
from .sharder import ModelSharder
|
||||
|
||||
|
||||
class ShardModel(object):
    r"""
    The class for sharding a huggingface model; ``self.model`` is the sharded model.
    Just create a new ShardModel object to shard a huggingface model.

    Args:
        model (:class:`torch.nn.Module`): the original huggingface model
        shard_config (:class:`ShardConfig`): the config for distribution information
        custom_policy (:class:`Policy`): the custom policy for sharding
    """

    def __init__(
        self,
        model: nn.Module,
        # Forward-reference annotations: ShardConfig/Policy are project types
        # imported at module top; strings keep the signature import-order safe.
        shard_config: 'ShardConfig' = None,    # TODO: build a default config when None
        custom_policy: 'Policy' = None,
    ) -> None:
        self.model = model
        self.shard_config = shard_config
        self.policy = custom_policy
        # self.layout=, # TODO

        # Sharding happens eagerly at construction time: ModelSharder mutates
        # self.model in place according to the policy and shard config.
        sharder = ModelSharder(
            model=self.model,
            policy=self.policy,
            shard_config=self.shard_config,
        )
        sharder.shard()

    def set_environ(self) -> None:
        """Export the distributed-environment variables and initialize the
        default process group (if not already initialized).

        NOTE(review): assumes ``self.shard_config`` exposes ``master_addr``,
        ``master_port``, ``num_gpus``, ``rank`` and ``backend`` — confirm
        against the ShardConfig dataclass definition.
        """
        os.environ["TOKENIZERS_PARALLELISM"] = "true"
        os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU"
        # Fix: __init__ stores the config as ``self.shard_config``;
        # ``self.dist_config`` was never assigned, so the original code
        # raised AttributeError on every call.
        os.environ["MASTER_ADDR"] = str(self.shard_config.master_addr)
        os.environ["MASTER_PORT"] = str(self.shard_config.master_port)
        os.environ["WORLD_SIZE"] = str(self.shard_config.num_gpus)
        os.environ["RANK"] = str(self.shard_config.rank)
        os.environ["LOCAL_RANK"] = str(self.shard_config.rank)
        if not dist.is_initialized():
            dist.init_process_group(backend=self.shard_config.backend)

        torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))

    def back_to_org(self) -> None:
        # Fix: the original definition omitted ``self``, so calling it on an
        # instance raised TypeError. Restoring the original (unsharded) model
        # is still unimplemented.
        pass
|
Loading…
Reference in new issue