From 61e687831d19ab5ad5e3255ac980807046bb8dc3 Mon Sep 17 00:00:00 2001
From: zbian
Date: Mon, 27 Feb 2023 17:52:16 +0800
Subject: [PATCH] fix weight access when using ZeRO with tensor parallelism

---
 .../nn/layer/colossalai_layer/_utils.py  | 79 ++++++++++---------
 .../nn/layer/colossalai_layer/dropout.py | 61 +++++++-------
 2 files changed, 72 insertions(+), 68 deletions(-)

diff --git a/colossalai/nn/layer/colossalai_layer/_utils.py b/colossalai/nn/layer/colossalai_layer/_utils.py
index 4283e5fe0..677cb0e7a 100644
--- a/colossalai/nn/layer/colossalai_layer/_utils.py
+++ b/colossalai/nn/layer/colossalai_layer/_utils.py
@@ -1,38 +1,41 @@
-import torch.nn as nn
-from torch import Tensor
-
-from ..parallel_2d._operation import split_batch_2d
-from ..parallel_2p5d._operation import split_batch_2p5d
-from ..parallel_3d._operation import split_batch_3d
-from ..utils import get_tensor_parallel_mode
-
-_parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d}
-
-
-def partition_batch(input_) -> Tensor:
-    tensor_parallel_mode = get_tensor_parallel_mode()
-    if tensor_parallel_mode in _parallel_split_batch:
-        if isinstance(input_, dict):
-            return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()}
-        else:
-            return _parallel_split_batch[tensor_parallel_mode](input_)
-    else:
-        return input_
-
-
-class ColossalaiModule(nn.Module):
-
-    def __init__(self, module: nn.Module, **kwargs):
-        super().__init__()
-        # copy values
-        self.__dict__ = module.__dict__.copy()
-        # copy methods
-        for name, attr in module.__class__.__dict__.items():
-            if name not in ['__init__', 'forward'] and callable(attr):
-                setattr(self, name, getattr(module, name))
-        self._forward_func = module.forward
-        for k, v in kwargs.items():
-            setattr(self, k, v)
-
-    def forward(self, *args):
-        return self._forward_func(*args)
+import torch.nn as nn
+from torch import Tensor
+
+from ..parallel_2d._operation import split_batch_2d
+from ..parallel_2p5d._operation import split_batch_2p5d
+from ..parallel_3d._operation import split_batch_3d
+from ..utils import get_tensor_parallel_mode
+
+_parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d}
+
+
+def partition_batch(input_) -> Tensor:
+    tensor_parallel_mode = get_tensor_parallel_mode()
+    if tensor_parallel_mode in _parallel_split_batch:
+        if isinstance(input_, dict):
+            return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()}
+        else:
+            return _parallel_split_batch[tensor_parallel_mode](input_)
+    else:
+        return input_
+
+
+class ColossalaiModule(nn.Module):
+
+    def __init__(self, module: nn.Module, **kwargs):
+        super().__init__()
+        self.module = module
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
+    def __getattr__(self, name: str):
+        if name == 'module':
+            return super().__getattr__(name)
+        elif hasattr(self.module, name):
+            return getattr(self.module, name)
+        elif name in self.__dict__:
+            return self.__dict__[name]
+        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))
+
+    def forward(self, *args):
+        return self.module(*args)
diff --git a/colossalai/nn/layer/colossalai_layer/dropout.py b/colossalai/nn/layer/colossalai_layer/dropout.py
index cc2d9a0a7..0c049cb3f 100644
--- a/colossalai/nn/layer/colossalai_layer/dropout.py
+++ b/colossalai/nn/layer/colossalai_layer/dropout.py
@@ -1,30 +1,31 @@
-import torch.nn as nn
-from colossalai.context import ParallelMode, seed
-
-from ..parallel_1d import *
-from ..utils import get_tensor_parallel_mode
-from ._utils import ColossalaiModule
-
-
-class Dropout(ColossalaiModule):
-    """Dropout layer of colossalai.
-
-    Args:
-        p (float, optional): probability of an element to be zeroed, defaults 0.5.
-        inplace (bool, optional): whether to do dropout in-place, default to be False.
-    """
-
-    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
-        tensor_parallel = get_tensor_parallel_mode()
-        if tensor_parallel == "1d":
-            drop = Dropout1D(p, inplace)
-        else:
-            drop = nn.Dropout(p, inplace)
-        super().__init__(drop, tensor_parallel=tensor_parallel)
-
-    def forward(self, *args):
-        if self.tensor_parallel in [None, '1d']:
-            return self._forward_func(*args)
-        else:
-            with seed(ParallelMode.TENSOR):
-                return self._forward_func(*args)
+import torch.nn as nn
+
+from colossalai.context import ParallelMode, seed
+
+from ..parallel_1d import *
+from ..utils import get_tensor_parallel_mode
+from ._utils import ColossalaiModule
+
+
+class Dropout(ColossalaiModule):
+    """Dropout layer of colossalai.
+
+    Args:
+        p (float, optional): probability of an element to be zeroed, defaults 0.5.
+        inplace (bool, optional): whether to do dropout in-place, default to be False.
+    """
+
+    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
+        tensor_parallel = get_tensor_parallel_mode()
+        if tensor_parallel == "1d":
+            drop = Dropout1D(p, inplace)
+        else:
+            drop = nn.Dropout(p, inplace)
+        super().__init__(drop, tensor_parallel=tensor_parallel)
+
+    def forward(self, *args):
+        if self.tensor_parallel in [None, '1d']:
+            return super().forward(*args)
+        else:
+            with seed(ParallelMode.TENSOR):
+                return super().forward(*args)
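Reviewer note: the patch replaces the old wrapper, which copied the wrapped module's __dict__ and bound its methods at construction time, with live delegation through __getattr__. Attributes that are replaced on the wrapped module afterwards (presumably how ZeRO ends up swapping out parameters such as weight, per the commit subject) therefore stay reachable through the wrapper instead of going stale. Below is a minimal sketch of the new delegation behaviour; it assumes the patched ColossalaiModule above and uses a plain torch.nn.Linear in place of a real tensor-parallel layer, so it illustrates the mechanism rather than reproducing the original ZeRO + TP failure.

    import torch
    import torch.nn as nn

    from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule

    # Wrap a plain linear layer. The old wrapper kept its own copy of the
    # module's state, so attributes replaced on the module later could not be
    # reached through the wrapper.
    linear = nn.Linear(4, 4)
    wrapped = ColossalaiModule(linear)

    # Attribute lookups now fall through __getattr__ to the live wrapped
    # module, so the wrapper always sees the module's current weight tensor.
    assert wrapped.weight is linear.weight

    # forward() simply calls the wrapped module.
    out = wrapped(torch.randn(2, 4))

Delegating on every lookup keeps the wrapper and the wrapped module from drifting apart, at the cost of one extra __getattr__ hop per attribute access.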