[pipeline] update shardformer docstring

pull/4445/head
ver217 authored 1 year ago, committed by Hongxin Liu
parent 59f6f573f1
commit b0b8ad2823

@@ -1,4 +1,7 @@
+from typing import Dict, List, Tuple
+
 import torch.nn as nn
+from torch import Tensor
 
 from colossalai.cluster import DistCoordinator
@@ -24,7 +27,7 @@ class ShardFormer:
     org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
     shard_config = ShardConfig()
     shard_former = ShardFormer(shard_config=shard_config)
-    model = shard_former.optimize(org_model)
+    model, shared_params = shard_former.optimize(org_model)
     ```
     """
@@ -32,7 +35,7 @@ class ShardFormer:
         self.coordinator = DistCoordinator()
         self.shard_config = shard_config
 
-    def optimize(self, model: nn.Module, policy: Policy = None):
+    def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
         r"""
         This method will optimize the model based on the given policy.
@@ -40,6 +43,8 @@ class ShardFormer:
             model (`torch.nn.Model`): the origin huggingface model
             shard_config (`ShardConfig`): the config for distribute information
             policy (`Policy`): the custom policy for sharding
+
+        Returns: the sharded model and the shared parameters
         """
         sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
         shared_params = sharder.shard()
