mirror of https://github.com/hpcaitech/ColossalAI
23 lines
645 B
Python
23 lines
645 B
Python
![]() |
from abc import ABC, abstractmethod
|
||
|
from typing import List, Optional
|
||
|
|
||
![]() |
import torch.distributed as dist
|
||
![]() |
|
||
|
from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
|
||
![]() |
|
||
![]() |
|
||
|
class BaseShardStrategy(ABC):
    """Abstract shard strategy, used to shard tensors across multiple GPUs.

    This base class defines no sharding behavior of its own: ``shard`` and
    ``gather`` are abstract and must be overridden by concrete strategies.
    """

    def __init__(self) -> None:
        """Initialize the strategy by chaining to the parent initializer."""
        super().__init__()

    @abstractmethod
    def shard(
        self,
        tensor_list: List[ShardedTensor],
        process_group: Optional[dist.ProcessGroup] = None,
    ):
        """Abstract hook; the base class provides no implementation.

        Args:
            tensor_list: The ``ShardedTensor`` objects to operate on.
            process_group: Optional process group; defaults to ``None``.
                NOTE(review): presumably ``None`` selects the default
                group -- confirm against the concrete strategies.
        """
        ...

    @abstractmethod
    def gather(
        self,
        tensor_list: List[ShardedTensor],
        process_group: Optional[dist.ProcessGroup] = None,
    ):
        """Abstract hook; the base class provides no implementation.

        Args:
            tensor_list: The ``ShardedTensor`` objects to operate on.
            process_group: Optional process group; defaults to ``None``.
                NOTE(review): presumably ``None`` selects the default
                group -- confirm against the concrete strategies.
        """
        ...
|