fixed weights not being accessible correctly when using ZeRO with tensor parallelism

pull/2933/head
zbian 2023-02-27 17:52:16 +08:00 committed by アマデウス
parent eb5cf94332
commit 61e687831d
2 changed files with 72 additions and 68 deletions

View File

@@ -1,38 +1,41 @@
 import torch.nn as nn
 from torch import Tensor
 
 from ..parallel_2d._operation import split_batch_2d
 from ..parallel_2p5d._operation import split_batch_2p5d
 from ..parallel_3d._operation import split_batch_3d
 from ..utils import get_tensor_parallel_mode
 
 _parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d}
 
 
 def partition_batch(input_) -> Tensor:
     tensor_parallel_mode = get_tensor_parallel_mode()
     if tensor_parallel_mode in _parallel_split_batch:
         if isinstance(input_, dict):
             return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()}
         else:
             return _parallel_split_batch[tensor_parallel_mode](input_)
     else:
         return input_
 
 
 class ColossalaiModule(nn.Module):
 
     def __init__(self, module: nn.Module, **kwargs):
         super().__init__()
-        # copy values
-        self.__dict__ = module.__dict__.copy()
-        # copy methods
-        for name, attr in module.__class__.__dict__.items():
-            if name not in ['__init__', 'forward'] and callable(attr):
-                setattr(self, name, getattr(module, name))
-        self._forward_func = module.forward
+        self.module = module
         for k, v in kwargs.items():
             setattr(self, k, v)
 
+    def __getattr__(self, name: str):
+        if name == 'module':
+            return super().__getattr__(name)
+        elif hasattr(self.module, name):
+            return getattr(self.module, name)
+        elif name in self.__dict__:
+            return self.__dict__[name]
+        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))
+
     def forward(self, *args):
-        return self._forward_func(*args)
+        return self.module(*args)
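
The core of the change is visible above: instead of copying module.__dict__ (a one-time snapshot of the wrapped layer's attributes), ColossalaiModule now keeps the layer as a registered submodule and forwards unknown attribute lookups to it via __getattr__, which appears to be why weights become reachable again once ZeRO re-wraps or shards them. The sketch below is not from the commit; it mirrors the new pattern with a plain nn.Linear, and the DelegatingWrapper name is a hypothetical stand-in for ColossalaiModule.

import torch
import torch.nn as nn


class DelegatingWrapper(nn.Module):
    """Mirrors the new ColossalaiModule: keep the wrapped module registered
    and forward unknown attribute lookups to it."""

    def __init__(self, module: nn.Module, **kwargs):
        super().__init__()
        self.module = module
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __getattr__(self, name: str):
        # called only when normal lookup fails; 'module' itself lives in
        # nn.Module's _modules registry, so route it to the parent class
        if name == 'module':
            return super().__getattr__(name)
        if hasattr(self.module, name):
            return getattr(self.module, name)
        raise AttributeError(name)

    def forward(self, *args):
        return self.module(*args)


linear = nn.Linear(4, 4)
wrapper = DelegatingWrapper(linear, tensor_parallel=None)

# attribute access is delegated, so the wrapper always sees the module's
# current parameters instead of a stale __dict__ snapshot
assert wrapper.weight is linear.weight
assert dict(wrapper.named_parameters())  # the wrapped module stays registered
print(wrapper(torch.randn(2, 4)).shape)  # torch.Size([2, 4])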

View File

@@ -1,30 +1,31 @@
 import torch.nn as nn
+
 from colossalai.context import ParallelMode, seed
 
 from ..parallel_1d import *
 from ..utils import get_tensor_parallel_mode
 from ._utils import ColossalaiModule
 
 
 class Dropout(ColossalaiModule):
     """Dropout layer of colossalai.
 
     Args:
         p (float, optional): probability of an element to be zeroed, defaults 0.5.
         inplace (bool, optional): whether to do dropout in-place, default to be False.
     """
 
     def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
         tensor_parallel = get_tensor_parallel_mode()
         if tensor_parallel == "1d":
             drop = Dropout1D(p, inplace)
         else:
             drop = nn.Dropout(p, inplace)
         super().__init__(drop, tensor_parallel=tensor_parallel)
 
     def forward(self, *args):
         if self.tensor_parallel in [None, '1d']:
-            return self._forward_func(*args)
+            return super().forward(*args)
         else:
             with seed(ParallelMode.TENSOR):
-                return self._forward_func(*args)
+                return super().forward(*args)
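
Because ColossalaiModule no longer exposes _forward_func, Dropout.forward now calls super().forward(*args), which resolves to ColossalaiModule.forward and runs the wrapped nn.Dropout (or Dropout1D) through self.module; for 2D/2.5D/3D modes that call happens inside the seed(ParallelMode.TENSOR) context, so the dropout mask is drawn from the RNG state registered for the tensor-parallel group. The sketch below is a hypothetical stand-in (WrappedDropout is not a library class, and contextlib.nullcontext replaces the real seed context) that only illustrates the dispatch structure.

import contextlib

import torch
import torch.nn as nn


class WrappedDropout(nn.Module):
    """Stand-in for the updated Dropout: pick the concrete dropout once,
    then route every call through the wrapper's own forward."""

    def __init__(self, p: float = 0.5, inplace: bool = False, tensor_parallel=None):
        super().__init__()
        # with '1d' tensor parallelism the real class builds Dropout1D instead
        self.module = nn.Dropout(p, inplace)
        self.tensor_parallel = tensor_parallel

    def forward(self, *args):
        if self.tensor_parallel in [None, '1d']:
            return self.module(*args)
        # stand-in for colossalai's seed(ParallelMode.TENSOR) context, which
        # switches to the RNG state registered for the tensor-parallel group
        with contextlib.nullcontext():
            return self.module(*args)


x = torch.randn(2, 8)
assert WrappedDropout(p=0.1)(x).shape == x.shape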