From 29159d9b5b1c2e4dd2dac7b795b6eeeb092dd993 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Mon, 25 Apr 2022 10:06:53 +0800 Subject: [PATCH] hotfix tensor unittest bugs (#862) --- colossalai/tensor/_ops/linear.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/colossalai/tensor/_ops/linear.py b/colossalai/tensor/_ops/linear.py index a57599e6e..824ce702c 100644 --- a/colossalai/tensor/_ops/linear.py +++ b/colossalai/tensor/_ops/linear.py @@ -8,6 +8,7 @@ from colossalai.core import global_context as gpc from packaging import version from colossalai.utils.cuda import get_current_device + @colo_op_impl(torch.nn.functional.linear) def colo_linear(types, args, kwargs, pg): """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. @@ -34,13 +35,13 @@ def colo_linear(types, args, kwargs, pg): elif weight.shard_spec == '1Drow': # Input:S[1] x Weight:S[0] = Output:P # All-Reduce(Output) + bias = res - assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size[-1], \ + assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \ 'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format( input_tensor.shape, weight.size, weight.size[-1] * gpc.tensor_parallel_size) # Input:S[1] input_per_partition = split_forward_gather_backward(input_tensor, ParallelMode.PARALLEL_1D, dim=-1) # Output:P - device = get_current_device() # TODO where to put to(deivce)? + device = get_current_device() # TODO where to put to(deivce)? weight_ = weight.torch_tensor().to(device) partial_output = torch.nn.functional.linear(input_per_partition, weight_) # Reduce(Output) @@ -50,7 +51,7 @@ def colo_linear(types, args, kwargs, pg): bias_ = bias.to(device) output = output + bias_ return output - + else: raise NotImplementedError else: