mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] update shardformer docstring
parent
59f6f573f1
commit
b0b8ad2823
|
@ -1,4 +1,7 @@
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
|
|
||||||
|
@ -24,7 +27,7 @@ class ShardFormer:
|
||||||
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
|
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
|
||||||
shard_config = ShardConfig()
|
shard_config = ShardConfig()
|
||||||
shard_former = ShardFormer(shard_config=shard_config)
|
shard_former = ShardFormer(shard_config=shard_config)
|
||||||
model = shard_former.optimize(org_model)
|
model, shared_params = shard_former.optimize(org_model)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -32,7 +35,7 @@ class ShardFormer:
|
||||||
self.coordinator = DistCoordinator()
|
self.coordinator = DistCoordinator()
|
||||||
self.shard_config = shard_config
|
self.shard_config = shard_config
|
||||||
|
|
||||||
def optimize(self, model: nn.Module, policy: Policy = None):
|
def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
|
||||||
r"""
|
r"""
|
||||||
This method will optimize the model based on the given policy.
|
This method will optimize the model based on the given policy.
|
||||||
|
|
||||||
|
@ -40,6 +43,8 @@ class ShardFormer:
|
||||||
model (`torch.nn.Module`): the original Hugging Face model
|
model (`torch.nn.Module`): the original Hugging Face model
|
||||||
shard_config (`ShardConfig`): the config for distributed information
|
shard_config (`ShardConfig`): the config for distributed information
|
||||||
policy (`Policy`): the custom policy for sharding
|
policy (`Policy`): the custom policy for sharding
|
||||||
|
|
||||||
|
Returns: the sharded model and the shared parameters
|
||||||
"""
|
"""
|
||||||
sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
|
sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
|
||||||
shared_params = sharder.shard()
|
shared_params = sharder.shard()
|
||||||
|
|
Loading…
Reference in New Issue