# ColossalAI/colossalai/zero/shard_utils/base_shard_strategy.py

from abc import ABC, abstractmethod
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
import torch.distributed as dist
from typing import List, Optional
class BaseShardStrategy(ABC):
    """Abstract base class for strategies that shard tensors across multiple GPUs.

    Concrete subclasses implement :meth:`shard` and :meth:`gather` to move the
    payload of each :class:`ShardedTensor` between its sharded and full forms.
    """

    def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
        """Record the communication group and cache its topology.

        Args:
            process_group (Optional[dist.ProcessGroup], optional): process group
                used for collective communication. Defaults to None, in which
                case ``torch.distributed`` falls back to the default group.
        """
        super().__init__()
        self.process_group = process_group
        # Cache world size and rank once so subclasses need not re-query them.
        self.world_size = dist.get_world_size(self.process_group)
        self.local_rank = dist.get_rank(self.process_group)

    @abstractmethod
    def shard(self, tensor_list: List[ShardedTensor]):
        """Shard every tensor in ``tensor_list``."""
        ...

    @abstractmethod
    def gather(self, tensor_list: List[ShardedTensor]):
        """Gather the full data of every tensor in ``tensor_list``."""
        ...