diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py
index 069a46ca5..6e0f90257 100644
--- a/colossalai/shardformer/shard/shardformer.py
+++ b/colossalai/shardformer/shard/shardformer.py
@@ -1,4 +1,7 @@
+from typing import Dict, List, Tuple
+
 import torch.nn as nn
+from torch import Tensor
 
 from colossalai.cluster import DistCoordinator
 
@@ -24,7 +27,7 @@ class ShardFormer:
         org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
         shard_config = ShardConfig()
         shard_former = ShardFormer(shard_config=shard_config)
-        model = shard_former.optimize(org_model)
+        model, shared_params = shard_former.optimize(org_model)
         ```
     """
 
@@ -32,7 +35,7 @@ class ShardFormer:
         self.coordinator = DistCoordinator()
         self.shard_config = shard_config
 
-    def optimize(self, model: nn.Module, policy: Policy = None):
+    def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
         r"""
         This method will optimize the model based on the given policy.
 
@@ -40,6 +43,8 @@ class ShardFormer:
            model (`torch.nn.Model`): the origin huggingface model
            shard_config (`ShardConfig`): the config for distribute information
            policy (`Policy`): the custom policy for sharding
+
+        Returns: the sharded model and the shared parameters
        """
        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
        shared_params = sharder.shard()
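
The user-visible change here is that `ShardFormer.optimize` now returns a tuple of the sharded model and the shared parameters (annotated `List[Dict[int, Tensor]]`) instead of the model alone, so existing call sites must unpack two values. Below is a minimal usage sketch of the new call site, mirroring the docstring example in this diff; the `colossalai.launch_from_torch` call and the `colossalai.shardformer` import path are assumptions about the surrounding setup and are not shown in these hunks.

```python
import colossalai
from transformers import BertForMaskedLM

# Import path assumed: the class is defined in colossalai/shardformer/shard/shardformer.py
# and is typically re-exported from the colossalai.shardformer package.
from colossalai.shardformer import ShardConfig, ShardFormer

# Assumption: the script runs under a distributed launcher (e.g. torchrun),
# so the process group can be initialized here.
colossalai.launch_from_torch(config={})

org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
shard_config = ShardConfig()
shard_former = ShardFormer(shard_config=shard_config)

# New signature: optimize() returns the sharded model plus a list of dicts of
# shared parameters (keyed by int, per the List[Dict[int, Tensor]] annotation).
model, shared_params = shard_former.optimize(org_model)
```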