mirror of https://github.com/hpcaitech/ColossalAI
YuliangLiu0306 committed 2 years ago (committed by GitHub)
3 changed files with 284 additions and 0 deletions
@@ -0,0 +1,158 @@
from typing import Callable, Optional

import torch
from torch.utils._pytree import tree_map

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
from colossalai.tensor.sharding_spec import ShardingSpec

shape_consistency_manager = ShapeConsistencyManager()


class DTensor(torch.Tensor):
    """
    A distributed tensor that wraps a local shard together with the layout
    describing how the global tensor is distributed across a device mesh.
    """

    def __init__(self, local_tensor: torch.Tensor, dist_layout: Layout):
        self.local_tensor = local_tensor
        self.data_type = local_tensor.dtype
        self.entire_shape = local_tensor.shape
        if dist_layout.entire_shape is None:
            dist_layout.entire_shape = self.entire_shape
        self.dist_layout = dist_layout
        self._apply_layout()

    @staticmethod
    def __new__(cls, local_tensor, layout):
        return torch.Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad)

    def __repr__(self):
        return f"DTensor({self.to_global()}, {self.dist_layout})"

    def __str__(self):
        return self.__repr__()

    def layout_convert(self, target_layout):
        '''
        Convert the layout of the tensor from the current layout to target_layout.
        '''
        source_spec = convert_layout_to_sharding_spec(self.dist_layout)
        target_spec = convert_layout_to_sharding_spec(target_layout)
        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
            self.local_tensor, source_spec, target_spec)
        self.dist_layout = target_layout

    def _apply_layout(self):
        '''
        Apply the layout to the local tensor during initialization.
        '''
        source_spec = construct_default_sharding_spec(self.local_tensor, self.device_mesh)
        target_spec = convert_layout_to_sharding_spec(self.dist_layout)
        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
            self.local_tensor, source_spec, target_spec)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        def filter_arg(arg):
            if isinstance(arg, DTensor):
                return arg.local_tensor
            else:
                return arg

        args = tree_map(filter_arg, args)
        kwargs = tree_map(filter_arg, kwargs)
        # TODO: to convert the result into a DTensor, we need to infer the layout of
        # the result from the layouts of the input tensors and the op type.
        return func(*args, **kwargs)

    @property
    def device_mesh(self):
        '''
        Return the device mesh of the tensor.
        '''
        return self.dist_layout.device_mesh

    @property
    def sharding_spec(self):
        '''
        Return the sharding specification of the tensor.
        '''
        return self.dist_layout.sharding_spec

    def to(self, *args, **kwargs):
        '''
        Move the tensor to a new device or convert it to a new dtype.
        '''
        self.local_tensor = self.local_tensor.to(*args, **kwargs)
        self.data_type = self.local_tensor.dtype
        self.dist_layout.device_type = self.local_tensor.device
        # TODO: update the device mesh process groups, or should we just cache
        # both the cpu process groups and the cuda process groups?
        return self

    def to_local(self):
        '''
        Return the local tensor on this rank.
        '''
        return self.local_tensor

    def to_global(self):
        '''
        Recover the global tensor from the distributed tensor.

        Note: this function all_gathers the local tensor into the global tensor and does
        not change the layout of the DTensor. It is mainly used for debugging or for
        checking the correctness of the distributed tensor.
        '''
        return to_global(self.local_tensor, convert_layout_to_sharding_spec(self.dist_layout))


def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
    '''
    Distribute the local tensor into a distributed tensor according to the given dist_layout.

    Args:
        local_tensor: tensor to be distributed.
        dist_layout: the layout specification of the distributed tensor.

    Returns:
        A `DTensor` object.
    '''
    return DTensor(local_tensor, dist_layout)


def distribute_module(module: torch.nn.Module, partition_fn: Optional[Callable] = None) -> torch.nn.Module:
    '''
    Convert all the parameters in the module to DTensor (DParam).

    Note: this function is subject to change, as DParam has not been implemented yet.
    '''
    if partition_fn is None:
        # without a partition function there is nothing to shard
        return module
    for name, param in module.named_parameters():
        if param is not None and not isinstance(param, DTensor):
            # TODO: we could convert the parameter to DParam here;
            # the type of the parameter could be an optional argument.
            setattr(module, name, torch.nn.Parameter(partition_fn(name, param.data)))
    return module


def convert_layout_to_sharding_spec(layout: Layout) -> ShardingSpec:
    '''
    Convert a Layout object into the corresponding ShardingSpec object.
    '''
    return ShardingSpec(device_mesh=layout.device_mesh,
                        entire_shape=layout.entire_shape,
                        dim_partition_dict=layout.sharding_spec.dim_partition_dict)


def construct_default_sharding_spec(
    tensor: torch.Tensor,
    device_mesh: DeviceMesh,
) -> ShardingSpec:
    '''
    Construct the default (fully replicated) sharding specification for the tensor.
    '''
    return ShardingSpec(device_mesh=device_mesh, entire_shape=tensor.shape, dim_partition_dict={})
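The test below exercises DTensor and distribute_tensor but not distribute_module, so here is a minimal sketch of how distribute_module might be driven with a partition function. The shard_first_dim helper and the 2x2 mesh are illustrative assumptions, not part of this commit; it shards dim 0 of each parameter along mesh axis 0 and hands the local shard back to distribute_module, which wraps it in an nn.Parameter.

# Sketch only; assumes a distributed environment has already been launched
# (see the test file below for how launch() is called).
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_module
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.sharding_spec import ShardingSpec

device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)


def shard_first_dim(name: str, param: torch.Tensor) -> torch.Tensor:
    # hypothetical helper: shard dim 0 of every parameter along mesh axis 0
    spec = ShardingSpec(device_mesh=device_mesh,
                        entire_shape=param.shape,
                        dim_partition_dict={0: [0]})
    layout = Layout(device_mesh=device_mesh, device_type=torch.device('cuda'), sharding_spec=spec)
    # until DParam exists, keep only the local shard of each parameter
    return DTensor(param, layout).to_local()


model = distribute_module(torch.nn.Linear(8, 8).to('cuda'), partition_fn=shard_first_dim)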
@@ -0,0 +1,22 @@
from dataclasses import dataclass

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec


@dataclass
class Layout:
    """Layout of a tensor.

    Attributes:
        device_mesh: the device mesh across which the tensor is distributed.
        device_type: the device type of the mesh, e.g. 'cpu' or 'cuda'.
        sharding_spec: the sharding specification describing how the tensor is sharded.
        entire_shape: the entire shape of the global tensor.
    """
    device_mesh: DeviceMesh
    device_type: torch.device
    sharding_spec: ShardingSpec
    entire_shape: torch.Size = None
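Note that entire_shape is optional: when it is left as None, DTensor.__init__ fills it in from the wrapped tensor's shape, which the test below relies on. A minimal construction sketch, where the 2x2 mesh and the (4, 8) shape are illustrative assumptions:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.sharding_spec import ShardingSpec

mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
spec = ShardingSpec(device_mesh=mesh, entire_shape=torch.Size((4, 8)), dim_partition_dict={0: [0]})
# entire_shape is left unset here; DTensor.__init__ copies it from the wrapped tensor
layout = Layout(device_mesh=mesh, device_type=torch.device('cuda'), sharding_spec=spec)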
@@ -0,0 +1,104 @@
from functools import partial

import torch
import torch.multiprocessing as mp

from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.utils import free_port


class TestModel(torch.nn.Module):

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear_1 = torch.nn.Linear(in_features, out_features)
        self.linear_2 = torch.nn.Linear(out_features, in_features)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.linear_2(x)
        return x


def check_dtensor(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    test_model = TestModel(8, 8).to('cuda')
    original_tensor = torch.rand(4, 8).to('cuda')
    compare_output = test_model(original_tensor)

    # build a 2x2 device mesh and shard dim 0 of the tensor along mesh axis 0
    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
    target_sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                        entire_shape=original_tensor.shape,
                                        dim_partition_dict={0: [0]})
    layout = Layout(device_mesh=device_mesh, device_type=torch.device('cuda'), sharding_spec=target_sharding_spec)
    d_tensor = DTensor(original_tensor, layout)

    assert d_tensor.entire_shape == original_tensor.shape
    assert d_tensor.data_type == original_tensor.dtype

    if rank in (0, 1):
        assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2))
    elif rank in (2, 3):
        assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')
    assert d_tensor.to_global().equal(original_tensor)
    output = test_model(d_tensor)

    if rank in (0, 1):
        assert output.equal(compare_output.narrow(0, 0, 2))
    elif rank in (2, 3):
        assert output.equal(compare_output.narrow(0, 2, 2))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')

    # reshard: split dim 0 along both mesh axes, so each rank holds a single row
    new_sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                     entire_shape=original_tensor.shape,
                                     dim_partition_dict={0: [0, 1]})
    new_layout = Layout(device_mesh=device_mesh,
                        device_type=torch.device('cuda'),
                        sharding_spec=new_sharding_spec,
                        entire_shape=original_tensor.shape)

    d_tensor.layout_convert(new_layout)

    if rank == 0:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
    elif rank == 1:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1))
    elif rank == 2:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1))
    elif rank == 3:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')

    dtensor_from_local = distribute_tensor(original_tensor, new_layout)

    if rank == 0:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
    elif rank == 1:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1))
    elif rank == 2:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1))
    elif rank == 3:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')


def test_dtensor():
    world_size = 4
    run_func = partial(check_dtensor, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_dtensor()