mirror of https://github.com/hpcaitech/ColossalAI

[hotfix] fix unit test test_module_spec (#1321)

parent 9e4c6449b0
commit 1b41686461
@@ -88,7 +88,7 @@ def init_colo_module(module: torch.nn.Module,
     compute_pattern = compute_spec.compute_pattern
     if is_colo_module(module):
         # for each param
-        # set DistSpec and ComputeSpec
+        # set its process_group, dist_spec and compute_spec
         colo_module = get_colo_module(module)
         colo_module.register(compute_pattern, pg)
         if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
@@ -101,6 +101,7 @@ def init_colo_module(module: torch.nn.Module,
                 continue
             param = module.get_parameter(param_name)
             if isinstance(param, ColoParameter):
+                param.set_process_group(pg)
                 param.set_dist_spec(dist_spec)
                 param.compute_spec = compute_spec
                 for mod in param.shared_param_modules:
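Note on the two hunks above: the fix gives each ColoParameter its process group before the dist spec and compute spec are applied, so the parameter is sharded over the intended group. A minimal sketch of that per-parameter ordering; `_setup_colo_param` is a hypothetical helper, and `param`, `pg`, `dist_spec`, `compute_spec` stand in for the objects handled inside `init_colo_module`:

    # Hypothetical helper mirroring the fixed per-parameter ordering (not the library's own code).
    def _setup_colo_param(param, pg, dist_spec, compute_spec):
        param.set_process_group(pg)        # 1) bind the ProcessGroup first
        param.set_dist_spec(dist_spec)     # 2) then distribute the data over that group
        param.compute_spec = compute_spec  # 3) finally record how the param joins the compute pattern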
@@ -18,7 +18,7 @@ def _get_my_nowrap_functions() -> Set[Callable]:
         Tensor._base.__get__,
         Tensor.grad.__get__,
         Tensor._grad.__get__,
-        Tensor.data.__get__,    # make .data returns torch.Tensor rather than ColoTensor
+        Tensor.data.__get__,    # make .data returns torch.Tensor rather than ColoTensor
     }
 
 
@@ -121,11 +121,13 @@ class ColoTensor(torch.Tensor):
             RuntimeError:
         """
         assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
-        if self.process_group.tp_world_size() != 1:
-            raise RuntimeError("can not set_process_group on a ColoTensor whose process_group has tp world group")
-
-        if self.dist_spec.placement.value != 'r':
-            raise RuntimeError("can not set_process_group on a ColoTensor whose dist spec is not REPLICATE")
+        # if the new pg is the same as the old pg, just returns
+        if self.process_group == pg:
+            return
+        assert self.process_group.tp_world_size() == 1, \
+            "Can not set_process_group on a ColoTensor whose process_group has tp world group"
+        assert self.dist_spec.placement.value == 'r', \
+            "Can not set_process_group on a ColoTensor whose dist spec is not REPLICATE"
 
         self.process_group = pg
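Note on the hunk above: re-setting an identical process group is now a silent no-op, and the two preconditions (tp world size 1, replicated placement) became assertions. A hedged usage sketch, assuming a distributed context has already been launched (e.g. via `colossalai.launch_from_torch`) and that `ProcessGroup` and `ColoTensorSpec` are importable from `colossalai.tensor` as in the test file further below:

    import torch
    from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup

    pg = ProcessGroup(tp_degree=1)
    t = ColoTensor.from_torch_tensor(torch.randn(4, 4), ColoTensorSpec(pg))
    t.set_process_group(pg)                         # same group: returns immediately, no error
    t.set_process_group(ProcessGroup(tp_degree=1))  # replicated and tp world size 1: allowed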
@@ -290,17 +292,17 @@ class ColoTensor(torch.Tensor):
 
     def is_replicate(self):
         return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
-            or (len(self.dist_spec.num_partitions) == 1
-                and self.dist_spec.num_partitions[0] == 1) \
-            or (self.process_group.tp_world_size() == 1)
+            or (len(self.dist_spec.num_partitions) == 1
+                and self.dist_spec.num_partitions[0] == 1) \
+            or (self.process_group.tp_world_size() == 1)
 
     def is_shard_1dcol(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
-            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
+            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
 
     def is_shard_1drow(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
-            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
+            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
 
     def is_sharded(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD
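For reference, the two 1D shard predicates above differ only in which dimension is partitioned: `dims == [-1]` splits the last dimension (column sharding of a linear weight), `dims == [0]` splits the first (row sharding). A standalone illustration using plain `torch.chunk`, assuming a tensor-parallel world size of 2; this is not the library's sharding code:

    import torch

    # A (out_features, in_features) = (8, 4) linear weight.
    weight = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)

    # SHARD with dims == [-1]: split along the last dim -> two (8, 2) pieces ("1D column").
    col_shards = torch.chunk(weight, chunks=2, dim=-1)

    # SHARD with dims == [0]: split along the first dim -> two (4, 4) pieces ("1D row").
    row_shards = torch.chunk(weight, chunks=2, dim=0)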
@@ -1,11 +1,11 @@
-from copy import copy
+from copy import deepcopy
 import pytest
 from functools import partial
 
 import torch
 import torch.multiprocessing as mp
 
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ShardSpec, ReplicaSpec
+from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ShardSpec, ColoTensorSpec
 from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
 from _utils import tensor_equal, tensor_shard_equal, set_seed
@@ -112,21 +112,25 @@ def run_linear_with_spec(mode):
     with ColoInitContext(device=get_current_device()):
         model = torch.nn.Linear(4, 8)
 
-    model_handy = copy(model)
+    model_handy = deepcopy(model)
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
     compute_spec = ComputeSpec(ComputePattern.TP1D)
     init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)
 
     x = torch.rand(2, 4).cuda()
+    colo_x = ColoTensor.from_torch_tensor(x, ColoTensorSpec(pg))
+
     out = model(x)
-    colo_out = model_handy(x)
+    colo_out = model_handy(colo_x)
     assert tensor_equal(out, colo_out)
 
     grad = torch.rand_like(out)
     out.backward(grad)
     colo_out.backward(grad)
-    assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
-    assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad, pg.tp_local_rank(), pg.tp_world_size())
+
+    assert tensor_shard_equal(model_handy.weight.grad, model.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
+    assert tensor_shard_equal(model_handy.bias.grad, model.bias.grad, pg.tp_local_rank(), pg.tp_world_size())
 
 
 def run_check_shared_param():
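Note on the hunk above: the switch from `copy` to `deepcopy` matters because a shallow copy of an `nn.Module` shares the original's Parameter objects, so sharding one copy would silently alter the reference model as well; the input is also wrapped into a `ColoTensor` bound to the same process group. A standalone demonstration of the aliasing issue in plain PyTorch, independent of this test harness:

    import torch
    from copy import copy, deepcopy

    ref = torch.nn.Linear(4, 8)
    shallow = copy(ref)
    deep = deepcopy(ref)

    assert shallow.weight is ref.weight          # shallow copy: same Parameter object, edits alias
    assert deep.weight is not ref.weight         # deep copy: independent storage, safe to shard
    assert torch.equal(deep.weight, ref.weight)  # values start out identical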
@@ -196,7 +200,7 @@ def run_dist_check(rank, world_size, port):
 
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
-@pytest.mark.skip("under development lazy init ColoParameter in Context")
+@pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_linear_1d(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
@@ -205,7 +209,7 @@ def test_module_linear_1d(world_size):
 
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
-@pytest.mark.skip("under development lazy init ColoParameter in Context")
+@pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_model(world_size):
     run_func = partial(run_dist_model, world_size=world_size, port=free_port())
@@ -214,7 +218,7 @@ def test_module_model(world_size):
 
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2])
-@pytest.mark.skip("under development lazy init ColoParameter in Context")
+@pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_check(world_size):
     run_func = partial(run_dist_check, world_size=world_size, port=free_port())
@@ -222,4 +226,4 @@ def test_module_check(world_size):
 
 
 if __name__ == '__main__':
-    test_module_check(2)
+    test_module_linear_1d(4)