2023-07-05 06:19:12 +00:00
|
|
|
from typing import Dict, List, Tuple
|
|
|
|
|
2023-06-15 09:55:42 +00:00
|
|
|
import torch.nn as nn
|
2023-07-05 06:19:12 +00:00
|
|
|
from torch import Tensor
|
2023-06-15 09:55:42 +00:00
|
|
|
|
2023-06-30 01:58:08 +00:00
|
|
|
from colossalai.cluster import DistCoordinator
|
2023-06-15 09:55:42 +00:00
|
|
|
|
2023-07-05 07:13:00 +00:00
|
|
|
from ..policies.base_policy import Policy
|
2023-06-15 09:55:42 +00:00
|
|
|
from .shard_config import ShardConfig
|
|
|
|
from .sharder import ModelSharder
|
|
|
|
|
|
|
|
|
|
|
|
class ShardFormer:
    """
    Entry point for parallelizing a model from a shard config and a policy.

    Constructing a `ShardFormer` creates a `DistCoordinator` and keeps the
    given `ShardConfig`; `optimize` then passes the model, that config and an
    optional policy to a `ModelSharder`.

    Example:

    ```python
    from colossalai.shardformer import ShardFormer, ShardConfig
    from transformers import BertForMaskedLM
    import colossalai
    import torch

    colossalai.launch_from_torch(config={})

    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    shard_config = ShardConfig()
    shard_former = ShardFormer(shard_config=shard_config)
    model, shared_params = shard_former.optimize(org_model)
    ```
    """

    def __init__(self, shard_config: ShardConfig):
        """
        Args:
            shard_config (`ShardConfig`): configuration that is later handed to
                the `ModelSharder` built in `optimize`
        """
        # The coordinator is created first, exactly as before, so a failure in
        # DistCoordinator() leaves no attribute set on the instance.
        self.coordinator = DistCoordinator()
        self.shard_config = shard_config

    def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
        """
        Shard `model` with this instance's shard config and an optional policy.

        Args:
            model (`torch.nn.Module`): the original (e.g. HuggingFace) model to shard
            policy (`Policy`): custom sharding policy; the default `None` is
                forwarded unchanged to `ModelSharder`

        Returns:
            A tuple `(model, shared_params)`: the very same `model` object that
            was passed in, followed by the value returned by
            `ModelSharder.shard()` (annotated as the shared parameters).
        """
        # NOTE(review): the input object itself is returned, so the sharder
        # presumably rewrites `model` in place -- confirm against ModelSharder.
        shared_params = ModelSharder(
            model=model,
            shard_config=self.shard_config,
            policy=policy,
        ).shard()
        return model, shared_params
|