mirror of https://github.com/hpcaitech/ColossalAI
import torch.nn as nn
from torch.utils.data import Dataset

from colossalai.cluster import DistCoordinator, ProcessGroupManager

from ..policies.basepolicy import Policy
from .shard_config import ShardConfig
from .sharder import ModelSharder


class ShardFormer:
    """
    Parallelize the model based on the given config and policy.

    Example:

    ```python
    from colossalai.shardformer import ShardFormer, ShardConfig
    from transformers import BertForMaskedLM
    import colossalai
    import torch

    colossalai.launch_from_torch(config={})

    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    shard_config = ShardConfig(
        tensor_parallel_size=2,
        tensor_parallel_mode='1d',
    )
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    model = shard_former.shard_model(org_model)
    ```
    """

    def __init__(self, shard_config: ShardConfig):
        """
        Do two things:
        1. Create a colossalai.cluster.process_group_manager to manage the process groups for dp, tp and pp
        2. Serve as a store for the shard config and the process group manager
        """
        self.coordinator = DistCoordinator()
        self.shard_config = shard_config
        self.pg_manager = None

    def init_distributed(self) -> ProcessGroupManager:
        """
        Initialize the distributed process groups according to the tensor parallel
        mode given in the shard config.
        """
        pg_manager = ProcessGroupManager()
        if self.shard_config.tensor_parallel_mode == '1d':
            # 1D tensor parallelism: put every rank in a single 'tp1d' group
            pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
        self.pg_manager = pg_manager
        return pg_manager
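
    # Note: `init_distributed` is expected to run before `shard_model` so that
    # `self.pg_manager` is populated; a minimal call sequence (mirroring the
    # class docstring example) looks like:
    #
    #     shard_former = ShardFormer(shard_config=shard_config)
    #     pg_manager = shard_former.init_distributed()
    #     model = shard_former.shard_model(org_model)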

    def shard_model(self, model: nn.Module, policy: Policy = None):
        r"""
        Shard the given PyTorch model in place according to the shard config, and return it.

        Args:
            model (`torch.nn.Module`): the original HuggingFace model
            policy (`Policy`): the custom policy for sharding
        """
        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy, pg_manager=self.pg_manager)
        sharder.shard()
        return model
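
    # Usage sketch with a custom policy (the `MyBertPolicy` subclass below is
    # hypothetical, for illustration only; any `Policy` subclass can be passed):
    #
    #     class MyBertPolicy(Policy):
    #         ...  # override the sharding behavior relevant to your model
    #
    #     model = shard_former.shard_model(org_model, policy=MyBertPolicy())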

    def shard_dataset(self, dataset: Dataset):
        """
        Shard the dataset for data parallelism (DP).
        """
        # Not implemented yet
        pass
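
# A minimal sketch of how `shard_dataset` could be filled in, assuming each
# data-parallel rank should see a disjoint shard of the dataset; this uses
# PyTorch's DistributedSampler and is an illustration, not the library's
# actual implementation:
#
#     import torch.distributed as dist
#     from torch.utils.data.distributed import DistributedSampler
#
#     def shard_dataset(self, dataset: Dataset) -> DistributedSampler:
#         # Each DP rank samples a disjoint 1/world_size slice of the dataset.
#         return DistributedSampler(
#             dataset,
#             num_replicas=dist.get_world_size(),
#             rank=dist.get_rank(),
#             shuffle=True,
#         )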