mirror of https://github.com/hpcaitech/ColossalAI
Authored by zbian 2 years ago, committed by アマデウス
2 changed files with 72 additions and 68 deletions
@@ -1,38 +1,41 @@

Before:

import torch.nn as nn
from torch import Tensor

from ..parallel_2d._operation import split_batch_2d
from ..parallel_2p5d._operation import split_batch_2p5d
from ..parallel_3d._operation import split_batch_3d
from ..utils import get_tensor_parallel_mode

_parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d}


def partition_batch(input_) -> Tensor:
    tensor_parallel_mode = get_tensor_parallel_mode()
    if tensor_parallel_mode in _parallel_split_batch:
        if isinstance(input_, dict):
            return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()}
        else:
            return _parallel_split_batch[tensor_parallel_mode](input_)
    else:
        return input_


class ColossalaiModule(nn.Module):

    def __init__(self, module: nn.Module, **kwargs):
        super().__init__()
        # copy values
        self.__dict__ = module.__dict__.copy()
        # copy methods
        for name, attr in module.__class__.__dict__.items():
            if name not in ['__init__', 'forward'] and callable(attr):
                setattr(self, name, getattr(module, name))
        self._forward_func = module.forward
        for k, v in kwargs.items():
            setattr(self, k, v)

    def forward(self, *args):
        return self._forward_func(*args)
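Worth noting about the old wrapper: `self.__dict__ = module.__dict__.copy()` is a shallow copy, so internal containers such as `_parameters`, `_buffers`, and `_modules` end up shared between the wrapper and the wrapped module. A minimal standalone sketch of that aliasing (plain PyTorch, not ColossalAI code):

import torch
import torch.nn as nn

linear = nn.Linear(4, 4)

wrapper = nn.Module()
wrapper.__dict__ = linear.__dict__.copy()   # shallow copy, as in the old wrapper above

# The copy only duplicates the outer dict; the parameter container is the same object,
# so a registration on the wrapper silently shows up on the original module as well.
print(wrapper._parameters is linear._parameters)            # True
wrapper.register_parameter('extra', nn.Parameter(torch.zeros(1)))
print('extra' in linear._parameters)                        # True as well

The rewritten class in the new version below keeps the wrapped module as a regular registered submodule and delegates attribute lookups to it instead of copying its state.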
After:

import torch.nn as nn
from torch import Tensor

from ..parallel_2d._operation import split_batch_2d
from ..parallel_2p5d._operation import split_batch_2p5d
from ..parallel_3d._operation import split_batch_3d
from ..utils import get_tensor_parallel_mode

_parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d}


def partition_batch(input_) -> Tensor:
    tensor_parallel_mode = get_tensor_parallel_mode()
    if tensor_parallel_mode in _parallel_split_batch:
        if isinstance(input_, dict):
            return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()}
        else:
            return _parallel_split_batch[tensor_parallel_mode](input_)
    else:
        return input_


class ColossalaiModule(nn.Module):

    def __init__(self, module: nn.Module, **kwargs):
        super().__init__()
        self.module = module
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __getattr__(self, name: str):
        if name == 'module':
            return super().__getattr__(name)
        elif hasattr(self.module, name):
            return getattr(self.module, name)
        elif name in self.__dict__:
            return self.__dict__[name]
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))

    def forward(self, *args):
        return self.module(*args)
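For context on how the rewritten wrapper behaves, here is a minimal standalone sketch of the same delegation pattern (a toy stand-in with hypothetical names, not the ColossalAI class itself): the wrapped module is registered as a submodule, `forward` calls it directly, and any attribute the wrapper does not define falls through to the wrapped module via `__getattr__`.

import torch
import torch.nn as nn


class WrapperModule(nn.Module):
    """Toy version of the delegation pattern used by the new ColossalaiModule."""

    def __init__(self, module: nn.Module, **kwargs):
        super().__init__()
        self.module = module                    # registered as a submodule
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __getattr__(self, name: str):
        # Called only when normal lookup fails; nn.Module.__getattr__ resolves
        # registered submodules/parameters/buffers, then we fall back to the
        # wrapped module for everything else.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(super().__getattr__('module'), name)

    def forward(self, *args):
        return self.module(*args)


linear = nn.Linear(4, 4)
wrapped = WrapperModule(linear, tensor_parallel=None)
x = torch.randn(2, 4)
print(wrapped(x).shape)                              # forward is delegated to the wrapped Linear
print(wrapped.out_features)                          # unknown attributes fall through to the Linear
print([n for n, _ in wrapped.named_parameters()])    # parameters appear once, under the 'module.' prefix

Because the wrapped module is an ordinary submodule, its parameters are registered exactly once, so optimizers and `state_dict` see them through the wrapper as usual.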
@@ -1,30 +1,31 @@

Before:

import torch.nn as nn
from colossalai.context import ParallelMode, seed

from ..parallel_1d import *
from ..utils import get_tensor_parallel_mode
from ._utils import ColossalaiModule


class Dropout(ColossalaiModule):
    """Dropout layer of colossalai.

    Args:
        p (float, optional): probability of an element to be zeroed, defaults 0.5.
        inplace (bool, optional): whether to do dropout in-place, default to be False.
    """

    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
        tensor_parallel = get_tensor_parallel_mode()
        if tensor_parallel == "1d":
            drop = Dropout1D(p, inplace)
        else:
            drop = nn.Dropout(p, inplace)
        super().__init__(drop, tensor_parallel=tensor_parallel)

    def forward(self, *args):
        if self.tensor_parallel in [None, '1d']:
            return self._forward_func(*args)
        else:
            with seed(ParallelMode.TENSOR):
                return self._forward_func(*args)
After:

import torch.nn as nn

from colossalai.context import ParallelMode, seed

from ..parallel_1d import *
from ..utils import get_tensor_parallel_mode
from ._utils import ColossalaiModule


class Dropout(ColossalaiModule):
    """Dropout layer of colossalai.

    Args:
        p (float, optional): probability of an element to be zeroed, defaults 0.5.
        inplace (bool, optional): whether to do dropout in-place, default to be False.
    """

    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
        tensor_parallel = get_tensor_parallel_mode()
        if tensor_parallel == "1d":
            drop = Dropout1D(p, inplace)
        else:
            drop = nn.Dropout(p, inplace)
        super().__init__(drop, tensor_parallel=tensor_parallel)

    def forward(self, *args):
        if self.tensor_parallel in [None, '1d']:
            return super().forward(*args)
        else:
            with seed(ParallelMode.TENSOR):
                return super().forward(*args)
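In the non-1D tensor-parallel modes, the rewritten `forward` still runs the wrapped dropout inside the `seed(ParallelMode.TENSOR)` context, so the dropout mask is drawn from an RNG stream that ColossalAI manages per parallel mode rather than from the default global one. The sketch below imitates that effect with plain PyTorch; `torch.random.fork_rng` plus the explicit `tp_seed` are stand-ins chosen for illustration, not the ColossalAI API.

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(8)

tp_seed = 1234   # hypothetical per-mode seed; ColossalAI derives its value from the parallel context


def dropout_with_dedicated_rng(module: nn.Module, inp: torch.Tensor, seed_value: int) -> torch.Tensor:
    # Fork the global RNG so the dropout mask comes from its own, explicitly
    # seeded stream and other random ops in the program are unaffected.
    with torch.random.fork_rng():
        torch.manual_seed(seed_value)
        return module(inp)


out_a = dropout_with_dedicated_rng(drop, x, tp_seed)
out_b = dropout_with_dedicated_rng(drop, x, tp_seed)
assert torch.equal(out_a, out_b)   # same seed, same mask, regardless of surrounding RNG use

Isolating the dropout RNG this way keeps the mask reproducible and decoupled from other random operations in the training step, which is the property the `seed(ParallelMode.TENSOR)` context provides across the tensor-parallel group.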