mirror of https://github.com/hpcaitech/ColossalAI
* polish code
* shard strategy receive pg in shard() / gather()
* update zero engine
* polish code
(pull/455/head)
ver217, 3 years ago, committed by GitHub
13 changed files with 84 additions and 80 deletions
@@ -1,26 +1,21 @@
 from abc import ABC, abstractmethod
-from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
-import torch.distributed as dist
 from typing import List, Optional

+import torch.distributed as dist
+from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+

 class BaseShardStrategy(ABC):

-    def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
+    def __init__(self) -> None:
         """Abstract Shard Strategy. Use to shard a tensors on multiple GPUs.
-
-        Args:
-            process_group (Optional[dist.ProcessGroup], optional): the process group. Defaults to None.
         """
-        self.process_group = process_group
-        self.world_size = dist.get_world_size(self.process_group)
-        self.local_rank = dist.get_rank(self.process_group)
         super().__init__()

     @abstractmethod
-    def shard(self, tensor_list: List[ShardedTensor]):
+    def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
         pass

     @abstractmethod
-    def gather(self, tensor_list: List[ShardedTensor]):
+    def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
         pass
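
For context, a minimal usage sketch of the changed interface follows. It is an assumption-laden illustration, not part of the commit: it assumes TensorShardStrategy from colossalai.zero.shard_utils as a concrete implementation of BaseShardStrategy and a distributed job launched with torchrun; the construction of the ShardedTensor list is version-specific and left as a placeholder. The point of the change is that the process group is now supplied per shard() / gather() call rather than bound to the strategy in __init__.

import torch.distributed as dist

# Assumption: a concrete strategy exported from colossalai.zero.shard_utils;
# the diff above only defines the abstract base class.
from colossalai.zero.shard_utils import TensorShardStrategy

# Requires a distributed launch (e.g. via torchrun) so the default group can be initialised.
dist.init_process_group(backend="nccl")
dp_group = dist.new_group()  # e.g. the data-parallel process group

# Before this commit the group was fixed at construction time:
#   strategy = TensorShardStrategy(process_group=dp_group)
# After it, the strategy is constructed without a group ...
strategy = TensorShardStrategy()

# ... and the group is chosen per call. ShardedTensor construction is
# version-specific, so an empty placeholder list keeps the sketch runnable.
sharded_tensors = []  # List[ShardedTensor]

strategy.shard(sharded_tensors, process_group=dp_group)   # each rank keeps only its shard
strategy.gather(sharded_tensors, process_group=dp_group)  # restore the full payloads

Passing the group per call lets a single strategy instance shard different tensor sets over different process groups, which appears to be the motivation for moving the argument out of the constructor.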