mirror of https://github.com/hpcaitech/ColossalAI
[Tensor] make a simple net works with 1D row TP (#879)
parent
c4d903e64a
commit
7f76517a85
|
@ -157,5 +157,21 @@ class ColoTensor(object):
|
||||||
def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
|
def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
|
||||||
self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
|
self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
|
||||||
|
|
||||||
|
## TODO(fjr) we reduce redundency of the following code
|
||||||
def __add__(self, o) -> "ColoTensor":
|
def __add__(self, o) -> "ColoTensor":
|
||||||
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
|
return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
|
||||||
|
|
||||||
|
def __truediv__(self, o) -> "ColoTensor":
|
||||||
|
return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)
|
||||||
|
|
||||||
|
def view(self, *args: int) -> "ColoTensor":
|
||||||
|
return ColoTensor.init_from_torch_tensor(self.torch_tensor().view(*args))
|
||||||
|
|
||||||
|
def permute(self, *args) -> "ColoTensor":
|
||||||
|
return ColoTensor.init_from_torch_tensor(self.torch_tensor().permute(*args))
|
||||||
|
|
||||||
|
def transpose(self, *args) -> "ColoTensor":
|
||||||
|
return ColoTensor.init_from_torch_tensor(self.torch_tensor().transpose(*args))
|
||||||
|
|
||||||
|
def contiguous(self):
|
||||||
|
return ColoTensor.init_from_torch_tensor(self.torch_tensor().contiguous())
|
||||||
|
|
|
@ -7,7 +7,8 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.cuda import get_current_device
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from colossalai.utils import ColoInitContext
|
from colossalai.utils import ColoInitContext
|
||||||
from colossalai.tensor import named_params_with_colotensor
|
from colossalai.tensor import named_params_with_colotensor, TensorSpec, ComputePattern, ParallelAction, ColoTensor
|
||||||
|
from colossalai.context import ParallelMode
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
@ -20,18 +21,32 @@ def run_simple_net():
|
||||||
with ColoInitContext(device=get_current_device()):
|
with ColoInitContext(device=get_current_device()):
|
||||||
model = model_builder(checkpoint=True)
|
model = model_builder(checkpoint=True)
|
||||||
|
|
||||||
|
parallel_action_list = [
|
||||||
|
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)
|
||||||
|
]
|
||||||
|
spec = TensorSpec(parallel_action_list)
|
||||||
|
|
||||||
|
# A naive way to set spec for all weights in Linear
|
||||||
|
for name, p in named_params_with_colotensor(model):
|
||||||
|
if not isinstance(p, ColoTensor):
|
||||||
|
continue
|
||||||
|
if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
|
||||||
|
p.set_spec(spec)
|
||||||
|
|
||||||
|
model.cuda()
|
||||||
|
|
||||||
for param in named_params_with_colotensor(model):
|
for param in named_params_with_colotensor(model):
|
||||||
print(param)
|
print(param)
|
||||||
# we set the Specs for weight of each linear.
|
|
||||||
# model.proj1.weight.set_spec('1Drow')
|
|
||||||
# model.proj2.weight.set_spec('1Drow')
|
|
||||||
|
|
||||||
for i, (data, label) in enumerate(train_dataloader):
|
for i, (data, label) in enumerate(train_dataloader):
|
||||||
output = model(data)
|
data = data.to(get_current_device())
|
||||||
|
label = label.to(get_current_device())
|
||||||
|
|
||||||
if criterion:
|
if criterion:
|
||||||
|
output = model(data)
|
||||||
loss = criterion(output, label)
|
loss = criterion(output, label)
|
||||||
else:
|
else:
|
||||||
|
output = model(data, label)
|
||||||
loss = output
|
loss = output
|
||||||
|
|
||||||
print(loss.torch_tensor())
|
print(loss.torch_tensor())
|
||||||
|
|
Loading…
Reference in New Issue