diff --git a/colossalai/tensor/_ops/linear.py b/colossalai/tensor/_ops/linear.py index a57599e6e..824ce702c 100644 --- a/colossalai/tensor/_ops/linear.py +++ b/colossalai/tensor/_ops/linear.py @@ -8,6 +8,7 @@ from colossalai.core import global_context as gpc from packaging import version from colossalai.utils.cuda import get_current_device + @colo_op_impl(torch.nn.functional.linear) def colo_linear(types, args, kwargs, pg): """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. @@ -34,13 +35,13 @@ def colo_linear(types, args, kwargs, pg): elif weight.shard_spec == '1Drow': # Input:S[1] x Weight:S[0] = Output:P # All-Reduce(Output) + bias = res - assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size[-1], \ + assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \ 'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format( input_tensor.shape, weight.size, weight.size[-1] * gpc.tensor_parallel_size) # Input:S[1] input_per_partition = split_forward_gather_backward(input_tensor, ParallelMode.PARALLEL_1D, dim=-1) # Output:P - device = get_current_device() # TODO where to put to(deivce)? + device = get_current_device() # TODO where to put to(deivce)? weight_ = weight.torch_tensor().to(device) partial_output = torch.nn.functional.linear(input_per_partition, weight_) # Reduce(Output) @@ -50,7 +51,7 @@ def colo_linear(types, args, kwargs, pg): bias_ = bias.to(device) output = output + bias_ return output - + else: raise NotImplementedError else: