ColossalAI/colossalai/zero/shard_utils/tensor_shard_strategy.py

19 lines
586 B
Python

from colossalai.zero.shard_utils import BaseShardStrategy
import torch.distributed as dist
from typing import List, Optional
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
class TensorShardStrategy(BaseShardStrategy):
    """Shard strategy that delegates to each tensor's own methods.

    Both operations walk the given list in order and call the matching
    method (``shard()`` or ``gather()``) on every element; this class does
    no partitioning or communication itself.
    """

    def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
        """Initialise the strategy.

        Args:
            process_group: Forwarded unchanged to ``BaseShardStrategy``.
                Defaults to ``None``.
                NOTE(review): how ``None`` is interpreted is decided by the
                base class, which is not visible here -- confirm there.
        """
        super().__init__(process_group)

    def shard(self, tensor_list: List[ShardedTensor]):
        """Call ``shard()`` on every tensor in ``tensor_list``, in order.

        Args:
            tensor_list: The sharded tensors to process. An empty list is a
                no-op.
        """
        for sharded_tensor in tensor_list:
            sharded_tensor.shard()

    def gather(self, tensor_list: List[ShardedTensor]):
        """Call ``gather()`` on every tensor in ``tensor_list``, in order.

        Args:
            tensor_list: The sharded tensors to process. An empty list is a
                no-op.
        """
        for sharded_tensor in tensor_list:
            sharded_tensor.gather()