mirror of https://github.com/hpcaitech/ColossalAI
[zero] Update initialize for ZeRO (#458)
* polish code
* shard strategy receives pg in shard() / gather()
* update zero engine
* polish code
parent 642846d6f9
commit a241f61b34
@@ -1,26 +1,21 @@
 from abc import ABC, abstractmethod
-from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
-import torch.distributed as dist
 from typing import List, Optional
+
+import torch.distributed as dist
+from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 
 
 class BaseShardStrategy(ABC):
 
-    def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
+    def __init__(self) -> None:
         """Abstract Shard Strategy. Use to shard a tensors on multiple GPUs.
-
-        Args:
-            process_group (Optional[dist.ProcessGroup], optional): the process group. Defaults to None.
         """
-        self.process_group = process_group
-        self.world_size = dist.get_world_size(self.process_group)
-        self.local_rank = dist.get_rank(self.process_group)
         super().__init__()
 
     @abstractmethod
-    def shard(self, tensor_list: List[ShardedTensor]):
+    def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        pass
 
     @abstractmethod
-    def gather(self, tensor_list: List[ShardedTensor]):
+    def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        pass
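After this change a shard strategy no longer binds to a single process group at construction time; the caller passes the group into every shard() / gather() call. The sketch below is a minimal illustration of a concrete strategy written against the new signatures, not the repository's own TensorShardStrategy: the colossalai.zero.shard_utils import path, the .payload / reset_payload accessors on ShardedTensor, and the shape bookkeeping in _origin_shapes are assumptions made for this example.

from typing import Dict, List, Optional

import torch
import torch.distributed as dist

from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor


class NaiveShardStrategy(BaseShardStrategy):
    """Hypothetical strategy: each rank keeps one equal-sized chunk of every tensor."""

    def __init__(self) -> None:
        super().__init__()
        # Simplification for the sketch: remember original shapes here;
        # the real ShardedTensor tracks its own metadata.
        self._origin_shapes: Dict[int, torch.Size] = {}

    def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        world_size = dist.get_world_size(process_group)
        rank = dist.get_rank(process_group)
        for t in tensor_list:
            payload = t.payload  # assumption: ShardedTensor exposes its data as `.payload`
            self._origin_shapes[id(t)] = payload.shape
            flat = payload.flatten()
            # Pad so the tensor splits evenly across the ranks of this group.
            pad = (-flat.numel()) % world_size
            if pad:
                flat = torch.nn.functional.pad(flat, (0, pad))
            shard = flat.chunk(world_size)[rank].clone()
            t.reset_payload(shard)  # assumption: in-place payload replacement

    def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        world_size = dist.get_world_size(process_group)
        for t in tensor_list:
            # All shards have equal size because shard() padded the tensor.
            buffer = [torch.empty_like(t.payload) for _ in range(world_size)]
            dist.all_gather(buffer, t.payload, group=process_group)
            shape = self._origin_shapes.pop(id(t))
            full = torch.cat(buffer)[:shape.numel()].reshape(shape)
            t.reset_payload(full)

A caller such as the ZeRO engine can then supply whichever group it shards over at call time, e.g. strategy.shard(tensors, process_group=dp_group), instead of relying on a group captured in __init__.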