ColossalAI/colossalai/shardformer/shard/shardformer.py

from typing import Dict, List, Tuple

import torch.nn as nn
from torch import Tensor

from colossalai.cluster import DistCoordinator

from ..policies.basepolicy import Policy
from .shard_config import ShardConfig
from .sharder import ModelSharder

class ShardFormer:
    """
    Parallelize a model based on the given config and policy.

    Example:

    ```python
    from colossalai.shardformer import ShardFormer, ShardConfig
    from transformers import BertForMaskedLM
    import colossalai
    import torch

    colossalai.launch_from_torch(config={})

    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    shard_config = ShardConfig()
    shard_former = ShardFormer(shard_config=shard_config)
    model, shared_params = shard_former.optimize(org_model)
    ```
    """
    def __init__(self, shard_config: ShardConfig):
        self.coordinator = DistCoordinator()
        self.shard_config = shard_config
    def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
        r"""
        Shard the given model according to ``self.shard_config`` and the given policy.

        Args:
            model (`torch.nn.Module`): the original Hugging Face model
            policy (`Policy`): the custom policy for sharding; if not given, a built-in policy is selected automatically

        Returns:
            The sharded model and the list of shared parameters.
        """
        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
        shared_params = sharder.shard()
        return model, shared_params
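

# A minimal usage sketch, assuming a distributed launch (e.g. via torchrun) and
# that the 'bert-base-uncased' checkpoint is available. It mirrors the example
# in the class docstring and additionally inspects the returned shared
# parameters, relying only on the List[Dict[int, Tensor]] return annotation.
if __name__ == "__main__":
    import colossalai
    from transformers import BertForMaskedLM

    from colossalai.shardformer import ShardConfig, ShardFormer

    colossalai.launch_from_torch(config={})

    org_model = BertForMaskedLM.from_pretrained("bert-base-uncased")
    shard_former = ShardFormer(shard_config=ShardConfig())

    # optimize() shards org_model in place and returns it together with the
    # parameters that remain shared between sharded modules.
    model, shared_params = shard_former.optimize(org_model)
    for group in shared_params:
        for key, param in group.items():
            print(key, tuple(param.shape))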