from abc import ABC, abstractmethod
from typing import List, Optional

import torch.distributed as dist
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor


class BaseShardStrategy(ABC):
    """Abstract shard strategy. Subclasses define how to shard tensors
    across multiple GPUs and how to gather the shards back into full tensors.
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        """Shard each tensor in ``tensor_list`` across the ranks of ``process_group``."""

    @abstractmethod
    def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        """Gather the shards of each tensor in ``tensor_list`` from the ranks of
        ``process_group``, restoring the full tensors."""
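

# The concrete strategy below is a hypothetical minimal sketch, not part of the
# original module. It assumes ShardedTensor exposes ``payload`` (the local
# torch.Tensor), ``reset_payload(tensor)``, an ``is_sharded`` flag, and the
# pre-shard metadata ``origin_shape`` / ``origin_numel``; verify these names
# against the actual ShardedTensor API before relying on this code.
import torch


class NaiveShardStrategy(BaseShardStrategy):
    """Keep one contiguous chunk per rank; reassemble with all_gather."""

    def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        rank = dist.get_rank(process_group)
        world_size = dist.get_world_size(process_group)
        for t in tensor_list:
            if t.is_sharded:
                continue
            payload = t.payload.flatten()
            # Pad so the flattened tensor splits evenly into world_size chunks.
            chunk_size = (payload.numel() + world_size - 1) // world_size
            padded = torch.zeros(chunk_size * world_size, dtype=payload.dtype, device=payload.device)
            padded[:payload.numel()].copy_(payload)
            # Keep only this rank's chunk as the local payload.
            t.reset_payload(padded.narrow(0, rank * chunk_size, chunk_size).clone())
            t.is_sharded = True

    def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        world_size = dist.get_world_size(process_group)
        for t in tensor_list:
            if not t.is_sharded:
                continue
            shard = t.payload
            buffers = [torch.empty_like(shard) for _ in range(world_size)]
            dist.all_gather(buffers, shard, group=process_group)
            full = torch.cat(buffers)  # still flattened and padded
            # Trim the padding and restore the original shape (relies on the
            # assumed ``origin_numel`` / ``origin_shape`` metadata).
            t.reset_payload(full[:t.origin_numel].view(t.origin_shape))
            t.is_sharded = False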