mirror of https://github.com/hpcaitech/ColossalAI
from abc import ABC, abstractmethod
from typing import List, Optional

import torch.distributed as dist

from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor

class BaseShardStrategy(ABC):
    """Abstract base class of a ZeRO shard strategy.

    A shard strategy defines how the payload of each ShardedTensor is split
    across the processes of a process group and how it is gathered back.
    """

    def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
        self.process_group = process_group
        # Resolved from the given group, or from the default group when
        # process_group is None.
        self.world_size = dist.get_world_size(self.process_group)
        self.local_rank = dist.get_rank(self.process_group)
        super().__init__()

    @abstractmethod
    def shard(self, tensor_list: List[ShardedTensor]):
        r"""Shard the memory payload of every tensor in ``tensor_list`` across
        the processes of the process group, so that each process keeps only
        its own shard.
        """
        pass

    @abstractmethod
    def gather(self, tensor_list: List[ShardedTensor]):
        r"""Gather the shards so that every process holds a full copy of the
        payload of each tensor in ``tensor_list``.
        """
        pass
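
For illustration, a minimal sketch of a concrete strategy built on this interface. The class name NaiveChunkShardStrategy is hypothetical, and the sketch assumes ShardedTensor exposes a readable and writable `payload` tensor attribute (an assumption, not verified against the ColossalAI API); the library's real strategies, such as TensorShardStrategy, additionally track the original shape, trim padding on gather, and manage device memory.

# Illustrative sketch only -- not part of the mirrored file.
# Assumes ShardedTensor has a readable/writable `.payload` tensor attribute.
import torch
import torch.nn.functional as F


class NaiveChunkShardStrategy(BaseShardStrategy):
    """Keep one equal-sized chunk of each flattened payload per rank."""

    def shard(self, tensor_list: List[ShardedTensor]):
        for t in tensor_list:
            full = t.payload.flatten()
            # Pad so the payload splits evenly across world_size ranks.
            chunk_size = (full.numel() + self.world_size - 1) // self.world_size
            padded = F.pad(full, (0, chunk_size * self.world_size - full.numel()))
            # Each rank keeps only its own chunk.
            t.payload = padded.chunk(self.world_size)[self.local_rank].clone()

    def gather(self, tensor_list: List[ShardedTensor]):
        for t in tensor_list:
            chunks = [torch.empty_like(t.payload) for _ in range(self.world_size)]
            # Collect every rank's chunk and rebuild the flattened payload.
            dist.all_gather(chunks, t.payload, group=self.process_group)
            # Restoring the original shape and trimming the padding is
            # omitted here; a real strategy must track both.
            t.payload = torch.cat(chunks)

An instance would typically be constructed after torch.distributed.init_process_group(...) and handed the group whose ranks should hold the shards; passing process_group=None falls back to the default group.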