mirror of https://github.com/hpcaitech/ColossalAI
fix: weights could not be accessed correctly when using ZeRO with TP
parent eb5cf94332
commit 61e687831d
@@ -1,38 +1,41 @@
 import torch.nn as nn
 from torch import Tensor
 
 from ..parallel_2d._operation import split_batch_2d
 from ..parallel_2p5d._operation import split_batch_2p5d
 from ..parallel_3d._operation import split_batch_3d
 from ..utils import get_tensor_parallel_mode
 
 _parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d}
 
 
 def partition_batch(input_) -> Tensor:
     tensor_parallel_mode = get_tensor_parallel_mode()
     if tensor_parallel_mode in _parallel_split_batch:
         if isinstance(input_, dict):
             return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()}
         else:
             return _parallel_split_batch[tensor_parallel_mode](input_)
     else:
         return input_
 
 
 class ColossalaiModule(nn.Module):
 
     def __init__(self, module: nn.Module, **kwargs):
         super().__init__()
-        # copy values
-        self.__dict__ = module.__dict__.copy()
-        # copy methods
-        for name, attr in module.__class__.__dict__.items():
-            if name not in ['__init__', 'forward'] and callable(attr):
-                setattr(self, name, getattr(module, name))
-        self._forward_func = module.forward
+        self.module = module
         for k, v in kwargs.items():
             setattr(self, k, v)
 
-    def forward(self, *args):
-        return self._forward_func(*args)
+    def __getattr__(self, name: str):
+        if name == 'module':
+            return super().__getattr__(name)
+        elif hasattr(self.module, name):
+            return getattr(self.module, name)
+        elif name in self.__dict__:
+            return self.__dict__[name]
+        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))
+
+    def forward(self, *args):
+        return self.module(*args)
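
To illustrate the delegation pattern this hunk introduces, here is a minimal, self-contained sketch; it is not part of the commit, and Wrapper, nn.Linear, and the zeroed replacement weight are only illustrative. Because nn.Module stores submodules in self._modules, a plain attribute lookup on the wrapper falls through to __getattr__, which forwards it to the live inner module, so attributes such as weight are always read from the wrapped module itself rather than from a copy of its __dict__ taken at construction time.

import torch
import torch.nn as nn


class Wrapper(nn.Module):
    """Toy version of the ColossalaiModule delegation pattern (illustrative only)."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module    # registered as a submodule by nn.Module

    def __getattr__(self, name: str):
        # Called only when normal lookup fails; 'module' itself lives in
        # self._modules, so fetch it through nn.Module.__getattr__.
        if name == 'module':
            return super().__getattr__(name)
        return getattr(self.module, name)

    def forward(self, *args):
        return self.module(*args)


linear = nn.Linear(4, 4)
wrapped = Wrapper(linear)
# Replace the inner parameter after wrapping (the kind of thing a sharding
# wrapper might do); the wrapper still resolves to the current tensor.
linear.weight = nn.Parameter(torch.zeros(4, 4))
assert wrapped.weight is linear.weight
print(wrapped(torch.ones(2, 4)).shape)    # torch.Size([2, 4])
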
@@ -1,30 +1,31 @@
 import torch.nn as nn
+
 from colossalai.context import ParallelMode, seed
 
 from ..parallel_1d import *
 from ..utils import get_tensor_parallel_mode
 from ._utils import ColossalaiModule
 
 
 class Dropout(ColossalaiModule):
     """Dropout layer of colossalai.
 
     Args:
         p (float, optional): probability of an element to be zeroed, defaults 0.5.
         inplace (bool, optional): whether to do dropout in-place, default to be False.
     """
 
     def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
         tensor_parallel = get_tensor_parallel_mode()
         if tensor_parallel == "1d":
             drop = Dropout1D(p, inplace)
         else:
             drop = nn.Dropout(p, inplace)
         super().__init__(drop, tensor_parallel=tensor_parallel)
 
     def forward(self, *args):
         if self.tensor_parallel in [None, '1d']:
-            return self._forward_func(*args)
+            return super().forward(*args)
         else:
             with seed(ParallelMode.TENSOR):
-                return super().forward(*args)
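
For context on this second hunk, a rough stand-alone analogue of the pattern Dropout.forward follows after the change is sketched below. Everything here is an assumption for illustration: MyDropout and use_tensor_rng are hypothetical, and torch.random.fork_rng merely stands in for colossalai's seed(ParallelMode.TENSOR) context manager. The point is the shape of the control flow: the inner dropout chosen at construction is invoked through the wrapper, and for tensor-parallel modes other than None/'1d' the call runs under a separate RNG context so the dropout mask is drawn from a dedicated seed.

import torch
import torch.nn as nn


class MyDropout(nn.Module):
    """Toy analogue of the colossalai Dropout wrapper (illustrative only)."""

    def __init__(self, p: float = 0.5, inplace: bool = False,
                 use_tensor_rng: bool = False):
        super().__init__()
        # Stand-in for choosing Dropout1D vs nn.Dropout by parallel mode.
        self.drop = nn.Dropout(p, inplace)
        self.use_tensor_rng = use_tensor_rng

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.use_tensor_rng:
            return self.drop(x)
        # Stand-in for `with seed(ParallelMode.TENSOR):` -- run the inner
        # dropout under its own, restorable RNG state.
        with torch.random.fork_rng():
            torch.manual_seed(1234)
            return self.drop(x)


x = torch.ones(2, 8)
print(MyDropout(p=0.5, use_tensor_rng=True)(x))
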