mirror of https://github.com/hpcaitech/ColossalAI
[Tensor ] Add 1Drow weight reshard by spec (#854)
parent
d7e0303d1e
commit
bcc8655021
|
@ -6,6 +6,7 @@ from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward
|
|||
from colossalai.nn.layer.utils import divide
|
||||
from colossalai.core import global_context as gpc
|
||||
from packaging import version
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
@colo_op_impl(torch.nn.functional.linear)
|
||||
def colo_linear(types, args, kwargs, pg):
|
||||
|
@ -39,12 +40,15 @@ def colo_linear(types, args, kwargs, pg):
|
|||
# Input:S[1]
|
||||
input_per_partition = split_forward_gather_backward(input_tensor, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
# Output:P
|
||||
partial_output = torch.nn.functional.linear(input_per_partition, weight.torch_tensor())
|
||||
device = get_current_device() # TODO where to put to(deivce)?
|
||||
weight_ = weight.torch_tensor().to(device)
|
||||
partial_output = torch.nn.functional.linear(input_per_partition, weight_)
|
||||
# Reduce(Output)
|
||||
output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
|
||||
# Bias
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
bias_ = bias.to(device)
|
||||
output = output + bias_
|
||||
return output
|
||||
|
||||
else:
|
||||
|
|
|
@ -3,7 +3,10 @@ from .op_wrapper import _COLOSSAL_OPS
|
|||
import torch
|
||||
from typing import Tuple, Optional
|
||||
from numpy import product
|
||||
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.nn.layer.utils import divide
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
class ColoTensor(object):
|
||||
""" Data Structure for Tensor in Colossal-AI
|
||||
|
@ -85,6 +88,28 @@ class ColoTensor(object):
|
|||
device=self._device)
|
||||
return self._torch_tensor
|
||||
|
||||
def set_spec(self, spec: str, lazy_shard: bool=False) -> None:
|
||||
self._shard_spec = spec
|
||||
if lazy_shard == False:
|
||||
self._shard()
|
||||
|
||||
def _shard(self):
|
||||
assert self._shard_spec is not None, 'You should call set_spec() before _shard() ColoTensor.'
|
||||
if self._shard_spec == "1Drow": # TODO It actually represents the sharding layout for Linear-1Drow-weight, but we make it simpler now.
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
||||
dim = -1
|
||||
chunk_size = divide(self._size[dim], num_partition)
|
||||
device = get_current_device()
|
||||
# Reshape to get shard for this rank and we don't want autograd
|
||||
# recording here for the narrow op and 'local_shard' should be a
|
||||
# leaf variable in the autograd graph.
|
||||
self._torch_tensor = self._torch_tensor.narrow(dim,
|
||||
local_rank * chunk_size, chunk_size).detach().contiguous() # TODO Shall we clone() here since detach() will point to the old tensor?
|
||||
self._torch_tensor.requires_grad = self._requires_grad
|
||||
self._size = self._torch_tensor.size()
|
||||
self._device = device # TODO A `fake` device now because torch_tensor.device always = cpu
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
global _COLOSSAL_OPS
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
from zmq import device
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from colossalai.nn import CheckpointModule
|
||||
from .utils.dummy_data_generator import DummyDataGenerator
|
||||
from .registry import non_distributed_component_funcs
|
||||
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
class SimpleNet(CheckpointModule):
|
||||
"""
|
||||
|
@ -25,8 +26,8 @@ class SimpleNet(CheckpointModule):
|
|||
class DummyDataLoader(DummyDataGenerator):
|
||||
|
||||
def generate(self):
|
||||
data = torch.rand(16, 4)
|
||||
label = torch.randint(low=0, high=2, size=(16,))
|
||||
data = torch.rand(16, 4, device=get_current_device())
|
||||
label = torch.randint(low=0, high=2, size=(16,), device=get_current_device())
|
||||
return data, label
|
||||
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ def run_linear_tp1d_row_test():
|
|||
|
||||
W_shape = (out_features, in_features)
|
||||
W_master = torch.randn(W_shape, dtype=dtype, device=device)
|
||||
W = broadcast_tensor_chunk(W_master, chunk_size=DEPTH, local_rank=local_rank)
|
||||
W = broadcast_tensor_chunk(W_master, chunk_size=1)
|
||||
W.requires_grad = True
|
||||
|
||||
B_shape = (out_features)
|
||||
|
@ -45,7 +45,7 @@ def run_linear_tp1d_row_test():
|
|||
|
||||
# replace the torch nn.Parameters with ColoTensor
|
||||
sharded_weight = ColoTensor.init_from_torch_tensor(W)
|
||||
sharded_weight._shard_spec = "1Drow"
|
||||
sharded_weight.set_spec(spec="1Drow") # reshard
|
||||
sharded_bias = ColoTensor.init_from_torch_tensor(B)
|
||||
replace_parameter_add_grad(layer, sharded_weight, sharded_bias)
|
||||
out = layer(A)
|
||||
|
|
|
@ -23,9 +23,9 @@ def run_simple_net():
|
|||
with ColoInitContext():
|
||||
model = model_builder(checkpoint=True)
|
||||
|
||||
# TODO(jzy) we set the Specs for weight of each linear.
|
||||
# model.proj1.weight.set_spec('1Drow')
|
||||
# model.proj2.weight.set_spec('1Drow')
|
||||
# we set the Specs for weight of each linear.
|
||||
model.proj1.weight.set_spec('1Drow')
|
||||
model.proj2.weight.set_spec('1Drow')
|
||||
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
output = model(data)
|
||||
|
|
Loading…
Reference in New Issue