mirror of https://github.com/hpcaitech/ColossalAI
fixed using zero with tp cannot access weight correctly
parent
eb5cf94332
commit
61e687831d
|
@ -24,15 +24,18 @@ class ColossalaiModule(nn.Module):
|
|||
|
||||
def __init__(self, module: nn.Module, **kwargs):
|
||||
super().__init__()
|
||||
# copy values
|
||||
self.__dict__ = module.__dict__.copy()
|
||||
# copy methods
|
||||
for name, attr in module.__class__.__dict__.items():
|
||||
if name not in ['__init__', 'forward'] and callable(attr):
|
||||
setattr(self, name, getattr(module, name))
|
||||
self._forward_func = module.forward
|
||||
self.module = module
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
if name == 'module':
|
||||
return super().__getattr__(name)
|
||||
elif hasattr(self.module, name):
|
||||
return getattr(self.module, name)
|
||||
elif name in self.__dict__:
|
||||
return self.__dict__[name]
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))
|
||||
|
||||
def forward(self, *args):
|
||||
return self._forward_func(*args)
|
||||
return self.module(*args)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import torch.nn as nn
|
||||
|
||||
from colossalai.context import ParallelMode, seed
|
||||
|
||||
from ..parallel_1d import *
|
||||
|
@ -24,7 +25,7 @@ class Dropout(ColossalaiModule):
|
|||
|
||||
def forward(self, *args):
|
||||
if self.tensor_parallel in [None, '1d']:
|
||||
return self._forward_func(*args)
|
||||
return super().forward(*args)
|
||||
else:
|
||||
with seed(ParallelMode.TENSOR):
|
||||
return self._forward_func(*args)
|
||||
return super().forward(*args)
|
||||
|
|
Loading…
Reference in New Issue