2022-03-04 03:59:35 +00:00
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from typing import List, Optional
|
|
|
|
|
2022-03-18 08:18:31 +00:00
|
|
|
import torch.distributed as dist
|
2023-04-04 05:48:16 +00:00
|
|
|
|
2023-09-18 08:31:06 +00:00
|
|
|
from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
|
2022-03-18 08:18:31 +00:00
|
|
|
|
2022-03-04 03:59:35 +00:00
|
|
|
|
|
|
|
class BaseShardStrategy(ABC):
    """Abstract shard strategy, used to shard tensors across multiple GPUs.

    This base class holds no state of its own: its constructor only chains to
    the parent initializer. Concrete strategies must implement both ``shard``
    and ``gather`` before they can be instantiated.
    """

    def __init__(self) -> None:
        """Initialize the strategy by delegating to the parent initializer."""
        super().__init__()

    @abstractmethod
    def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        """Shard the tensors in ``tensor_list``; the behavior is defined by subclasses.

        Args:
            tensor_list: The ``ShardedTensor`` objects to operate on.
            process_group: Optional process group, ``None`` by default.
                NOTE(review): presumably ``None`` means "use the default
                group" -- confirm against the concrete strategies.
        """

    @abstractmethod
    def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        """Gather the tensors in ``tensor_list``; the behavior is defined by subclasses.

        Args:
            tensor_list: The ``ShardedTensor`` objects to operate on.
            process_group: Optional process group, ``None`` by default.
                NOTE(review): presumably ``None`` means "use the default
                group" -- confirm against the concrete strategies.
        """
|