fix: weights cannot be accessed correctly when using ZeRO with tensor parallelism

pull/2933/head
zbian 2023-02-27 17:52:16 +08:00 committed by アマデウス
parent eb5cf94332
commit 61e687831d
2 changed files with 72 additions and 68 deletions


@@ -24,15 +24,18 @@ class ColossalaiModule(nn.Module):
 
     def __init__(self, module: nn.Module, **kwargs):
         super().__init__()
-        # copy values
-        self.__dict__ = module.__dict__.copy()
-        # copy methods
-        for name, attr in module.__class__.__dict__.items():
-            if name not in ['__init__', 'forward'] and callable(attr):
-                setattr(self, name, getattr(module, name))
-        self._forward_func = module.forward
+        self.module = module
         for k, v in kwargs.items():
             setattr(self, k, v)
 
+    def __getattr__(self, name: str):
+        if name == 'module':
+            return super().__getattr__(name)
+        elif hasattr(self.module, name):
+            return getattr(self.module, name)
+        elif name in self.__dict__:
+            return self.__dict__[name]
+        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))
+
     def forward(self, *args):
-        return self._forward_func(*args)
+        return self.module(*args)
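The heart of the fix is the switch above from snapshotting the wrapped module (copying its __dict__ and caching its bound forward) to holding it as a registered child and resolving attributes through __getattr__ on demand, so lookups such as wrapper.weight always reach whatever the wrapped layer currently holds. Below is a minimal, self-contained sketch of that delegation pattern; the Wrapper class is illustrative only, not ColossalAI's code.

import torch.nn as nn


class Wrapper(nn.Module):
    """Illustrative stand-in for a ColossalaiModule-style wrapper."""

    def __init__(self, module: nn.Module):
        super().__init__()
        # nn.Module.__setattr__ registers `module` as a child, so anything
        # that walks model.modules() (hooks, ZeRO, TP helpers) still sees it.
        self.module = module

    def __getattr__(self, name: str):
        # Called only when normal lookup fails; fall through to the wrapped
        # layer so `weight`, `bias`, etc. resolve to its current attributes.
        if name == 'module':
            return super().__getattr__(name)
        return getattr(self.module, name)

    def forward(self, *args):
        return self.module(*args)


layer = Wrapper(nn.Linear(8, 8))
assert layer.weight is layer.module.weight            # always the live parameter
assert dict(layer.named_modules())['module'] is layer.module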


@@ -1,4 +1,5 @@
 import torch.nn as nn
+
 from colossalai.context import ParallelMode, seed
 
 from ..parallel_1d import *
@@ -24,7 +25,7 @@ class Dropout(ColossalaiModule):
 
     def forward(self, *args):
         if self.tensor_parallel in [None, '1d']:
-            return self._forward_func(*args)
+            return super().forward(*args)
         else:
             with seed(ParallelMode.TENSOR):
-                return self._forward_func(*args)
+                return super().forward(*args)
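The Dropout change follows from the first file: ColossalaiModule no longer stores a cached _forward_func, so the subclass has to chain through super().forward, which now dispatches to the wrapped module. Below is a small sketch of the resulting call chain, using illustrative stand-in classes (Wrapper, SeededDropout) rather than the ColossalAI ones and omitting the seed(ParallelMode.TENSOR) context that the real class adds around the non-1D branch.

import torch
import torch.nn as nn


class Wrapper(nn.Module):
    """Illustrative base: forward simply dispatches to the wrapped module."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args):
        return self.module(*args)


class SeededDropout(Wrapper):
    """Illustrative subclass: overrides forward but chains to the base,
    mirroring `return super().forward(*args)` in the diff above."""

    def forward(self, *args):
        # The real Dropout wraps this call in the tensor-parallel RNG context
        # when tensor parallelism is active; the dispatch path is the same.
        return super().forward(*args)


drop = SeededDropout(nn.Dropout(p=0.5))
out = drop(torch.randn(2, 3))    # SeededDropout.forward -> Wrapper.forward -> nn.Dropout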