mirror of https://github.com/hpcaitech/ColossalAI
[TP] Add gather_out arg to Linear (#541)
parent
8c90d4df54
commit
763dc325f1
|
@ -1,4 +1,5 @@
|
|||
import math
|
||||
import inspect
|
||||
from typing import Callable
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
|
@ -78,15 +79,19 @@ class Linear(nn.Module):
|
|||
if self.layer.bias is not None:
|
||||
bias_initializer(self.layer.bias, fan_in=in_features)
|
||||
else:
|
||||
self.layer = _parallel_linear[tensor_parallel](
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
**kwargs,
|
||||
)
|
||||
linear_cls = _parallel_linear[tensor_parallel]
|
||||
gather_output = kwargs.pop('gather_output', None)
|
||||
if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys(): # gather_out arg is available
|
||||
kwargs['gather_output'] = gather_output
|
||||
self.layer = linear_cls(
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
|
|
Loading…
Reference in New Issue