mirror of https://github.com/hpcaitech/ColossalAI
Frank Lee
2 years ago
committed by
FrankLeeeee
7 changed files with 35 additions and 87 deletions
@ -0,0 +1,5 @@ |
|||||||
|
from .shard_config import ShardConfig |
||||||
|
from .sharder import ModelSharder, shard_model |
||||||
|
from .slicer import Slicer |
||||||
|
|
||||||
|
__all__ = ['ShardConfig', 'ModelSharder', 'shard_model', 'Slicer'] |
@ -1,5 +1,7 @@ |
|||||||
from dataclasses import dataclass |
from dataclasses import dataclass |
||||||
|
|
||||||
|
__all__ = ['ShardConfig'] |
||||||
|
|
||||||
|
|
||||||
@dataclass |
@dataclass |
||||||
class ShardConfig: |
class ShardConfig: |
@ -1,60 +0,0 @@ |
|||||||
import os |
|
||||||
from contextlib import suppress |
|
||||||
from dataclasses import dataclass |
|
||||||
|
|
||||||
import torch |
|
||||||
import torch.distributed as dist |
|
||||||
import torch.nn as nn |
|
||||||
import transformers |
|
||||||
|
|
||||||
from colossalai.tensor.d_tensor.layout import Layout |
|
||||||
|
|
||||||
from ..policies.basepolicy import Policy |
|
||||||
from .shardconfig import ShardConfig |
|
||||||
from .sharder import ModelSharder |
|
||||||
|
|
||||||
|
|
||||||
class ShardModel(object):
    r"""
    The class for sharding the huggingface model, ``self.model`` is the sharded model.
    Just create a new ShardModel object to shard a huggingface model; sharding is
    performed eagerly in ``__init__`` via :class:`ModelSharder`.

    Args:
        model (:class:`torch.nn.Module`): the origin huggingface model
        shard_config (:class:`ShardConfig`): the config for distribute information
        custom_policy (:class:`Policy`): the custom policy for sharding
    """

    def __init__(
        self,
        model: nn.Module,
        shard_config: ShardConfig = None,    # TODO
        custom_policy: Policy = None,
    ) -> None:
        self.model = model
        self.shard_config = shard_config
        self.policy = custom_policy
        # self.layout=, # TODO

        # Shard the wrapped model in place; a ``None`` policy lets the
        # sharder fall back to its default behavior.
        sharder = ModelSharder(
            model=self.model,
            policy=self.policy,
            shard_config=self.shard_config,
        )
        sharder.shard()

    def set_environ(self) -> None:
        """Export the distributed-training environment variables taken from the
        shard config and initialize the default process group if needed.

        BUG FIX: the original read ``self.dist_config``, an attribute that is
        never set (``__init__`` stores the config as ``self.shard_config``), so
        every call raised :class:`AttributeError`.
        NOTE(review): assumes the config exposes ``master_addr``, ``master_port``,
        ``num_gpus``, ``rank`` and ``backend`` — the ``ShardConfig`` body is not
        visible here; confirm against its definition.
        """
        os.environ["TOKENIZERS_PARALLELISM"] = "true"
        os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU"
        os.environ["MASTER_ADDR"] = str(self.shard_config.master_addr)
        os.environ["MASTER_PORT"] = str(self.shard_config.master_port)
        os.environ["WORLD_SIZE"] = str(self.shard_config.num_gpus)
        os.environ["RANK"] = str(self.shard_config.rank)
        os.environ["LOCAL_RANK"] = str(self.shard_config.rank)
        if not dist.is_initialized():
            dist.init_process_group(backend=self.shard_config.backend)

        # Bind this process to the GPU matching its local rank.
        torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))

    def back_to_org(self) -> None:
        # BUG FIX: the original declared ``def back_to_org() -> None`` without
        # ``self``, so any ``instance.back_to_org()`` call raised TypeError.
        # Restoring the original (unsharded) model is not implemented yet.
        pass
|
Loading…
Reference in new issue