mirror of https://github.com/hpcaitech/ColossalAI
19 lines
586 B
Python
19 lines
586 B
Python
from colossalai.zero.shard_utils import BaseShardStrategy
|
|
import torch.distributed as dist
|
|
from typing import List, Optional
|
|
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
|
|
|
|
|
class TensorShardStrategy(BaseShardStrategy):
    """Shard strategy that delegates to each tensor's own shard/gather methods.

    Every operation walks the given list in order and invokes the
    same-named method on each ``ShardedTensor``; the strategy itself holds
    no additional logic.
    """

    def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
        """Forward ``process_group`` unchanged to the base strategy.

        Args:
            process_group: Optional process group handed to
                ``BaseShardStrategy.__init__``. How ``None`` is interpreted
                is decided by the base class (not visible here).
        """
        super().__init__(process_group)

    def shard(self, tensor_list: List[ShardedTensor]):
        """Call ``shard()`` on every tensor in ``tensor_list``, in order.

        Args:
            tensor_list: Sharded tensors to process. An empty list is a
                no-op.
        """
        # Plain loop on purpose: the calls are made purely for their side
        # effects, so there is nothing to collect.
        for sharded_tensor in tensor_list:
            sharded_tensor.shard()

    def gather(self, tensor_list: List[ShardedTensor]):
        """Call ``gather()`` on every tensor in ``tensor_list``, in order.

        Args:
            tensor_list: Sharded tensors to process. An empty list is a
                no-op.
        """
        for sharded_tensor in tensor_list:
            sharded_tensor.gather()
|