mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] update binary elementwise handler (#2451)
* [autoparallel] update binary elementwise handler
* polish

branch: pull/2396/head
parent: c9ec5190a0
commit: 8221fd7485
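In short: _get_arg_value in BinaryElementwiseHandler now also reports whether an operand is a non-tensor, such operands are recorded in self.non_tensor_list, and post_process drops their sharding specs by membership in that list instead of re-checking the data type. The test is split into test_binary_elementwise_handler_with_tensor and test_binary_elementwise_handler_with_int, with two new model classes (BEOpModelWithNodeConst, BEOpModelWithIntConst) covering scalar operands that come from another graph node versus plain Python constants. A minimal PyTorch sketch of those two scalar-operand patterns (illustrative only, not ColossalAI code):

import torch

x = torch.rand(4, 4)

out_int = torch.add(x, 2)         # scalar operand is a plain Python int (BEOpModelWithIntConst path)
out_node = torch.add(x, x.dim())  # scalar operand is produced by another graph node (BEOpModelWithNodeConst path)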
@@ -32,20 +32,32 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
                 return OperationDataType.ARG
 
         def _get_arg_value(idx):
+            non_tensor = False
             if isinstance(self.node.args[idx], Node):
                 meta_data = self.node.args[idx]._meta_data
+                # The meta_data of node type argument could also possibly be a non-tensor object.
+                if not isinstance(meta_data, torch.Tensor):
+                    assert isinstance(meta_data, (int, float))
+                    meta_data = torch.Tensor([meta_data]).to('meta')
+                    non_tensor = True
 
             else:
                 # this is in fact a real data like int 1
                 # but we can deem it as meta data
                 # as it won't affect the strategy generation
                 assert isinstance(self.node.args[idx], (int, float))
                 meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
-            return meta_data
+                non_tensor = True
+
+            return meta_data, non_tensor
 
-        input_meta_data = _get_arg_value(0)
-        other_meta_data = _get_arg_value(1)
+        input_meta_data, non_tensor_input = _get_arg_value(0)
+        other_meta_data, non_tensor_other = _get_arg_value(1)
         output_meta_data = self.node._meta_data
+        # we need record op_data with non-tensor data in this list,
+        # and filter the non-tensor op_data in post_process.
+        self.non_tensor_list = []
+        # assert False
         input_op_data = OperationData(name=str(self.node.args[0]),
                                       type=_get_op_data_type(input_meta_data),
                                       data=input_meta_data,
@@ -58,6 +70,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
                                        type=OperationDataType.OUTPUT,
                                        data=output_meta_data,
                                        logical_shape=bcast_shape)
+        if non_tensor_input:
+            self.non_tensor_list.append(input_op_data)
+        if non_tensor_other:
+            self.non_tensor_list.append(other_op_data)
 
         mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
         return mapping
@@ -73,9 +89,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
         op_data_mapping = self.get_operation_data_mapping()
 
         for op_name, op_data in op_data_mapping.items():
-            if not isinstance(op_data.data, torch.Tensor):
+            if op_data in self.non_tensor_list:
                 # remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2)
                 strategy.sharding_specs.pop(op_data)
+
             else:
                 # convert the logical sharding spec to physical sharding spec if broadcast
                 # e.g. torch.rand(4, 4) + torch.rand(4)
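As a hedged sketch of the behaviour the hunks above implement: a scalar operand is wrapped into a one-element meta tensor so strategy generation can treat it like a tensor, and it is flagged as non-tensor so that post_process can drop its sharding spec again. The helper below is a hypothetical standalone mirror of _get_arg_value, not the actual class method:

import torch

def get_arg_value(arg):
    # standalone mirror of the handler's _get_arg_value: wrap scalars into
    # 1-element meta tensors and report that the operand was a non-tensor
    non_tensor = False
    if not isinstance(arg, torch.Tensor):
        assert isinstance(arg, (int, float))
        arg = torch.Tensor([arg]).to('meta')
        non_tensor = True
    return arg, non_tensor

meta, non_tensor = get_arg_value(2)                 # scalar operand, e.g. torch.pow(tensor, 2)
print(meta.shape, meta.is_meta, non_tensor)         # torch.Size([1]) True True
meta, non_tensor = get_arg_value(torch.rand(4, 4))  # ordinary tensor operand
print(meta.shape, meta.is_meta, non_tensor)         # torch.Size([4, 4]) False False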
@@ -122,25 +122,41 @@ def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size
     assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
 
 
-def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port):
-    disable_existing_loggers()
-    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-    class BinaryElementwiseOpModel(nn.Module):
-
-        def __init__(self, op, const):
-            super().__init__()
-            self.op = op
-            self.const = const
-
-        def forward(self, x1):
-            out = self.op(x1, self.const)
-            return out
-
+class BEOpModelWithNodeConst(nn.Module):
+
+    def __init__(self, op):
+        super().__init__()
+        self.op = op
+
+    def forward(self, x1):
+        const = x1.dim()
+        out = self.op(x1, const)
+        return out
+
+
+class BEOpModelWithIntConst(nn.Module):
+
+    def __init__(self, op, const):
+        super().__init__()
+        self.op = op
+        self.const = const
+
+    def forward(self, x1):
+        out = self.op(x1, self.const)
+        return out
+
+
+def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    model = BinaryElementwiseOpModel(op, other_dim).cuda()
+    if model_cls == BEOpModelWithNodeConst:
+        model = model_cls(op).cuda()
+    else:
+        model = model_cls(op, other_dim).cuda()
     x1 = torch.rand(4, 4).cuda()
     # the index of binary-elementwise node in computation graph
     node_index = 1
@@ -159,9 +175,14 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, p
     tracer = ColoTracer()
     meta_args = {'x1': torch.rand(4, 4).to('meta')}
     graph = tracer.trace(model, meta_args=meta_args)
+    print(graph)
+    # assert False
     gm = ColoGraphModule(model, graph)
 
-    op_node = list(graph.nodes)[1]
+    if model_cls == BEOpModelWithNodeConst:
+        op_node = list(graph.nodes)[2]
+    else:
+        op_node = list(graph.nodes)[1]
     strategies_vector = StrategiesVector(op_node)
 
     # build handler
@@ -212,7 +233,7 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, p
 @parameterize('other_dim', [1, 2])
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
-def test_binary_elementwise_handler(op, other_dim):
+def test_binary_elementwise_handler_with_tensor(op, other_dim):
     world_size = 4
     run_func_tensor = partial(check_binary_elementwise_handler_with_tensor,
                               op=op,
@@ -220,8 +241,19 @@ def test_binary_elementwise_handler(op, other_dim):
                               world_size=world_size,
                               port=free_port())
     mp.spawn(run_func_tensor, nprocs=world_size)
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('op', [torch.add])
+@parameterize('other_dim', [1, 2])
+@parameterize('model_cls', [BEOpModelWithNodeConst, BEOpModelWithIntConst])
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_binary_elementwise_handler_with_int(op, model_cls, other_dim):
+    world_size = 4
     run_func_int = partial(check_binary_elementwise_handler_with_int,
                            op=op,
+                           model_cls=model_cls,
                            other_dim=other_dim,
                            world_size=world_size,
                            port=free_port())
@@ -229,4 +261,5 @@ def test_binary_elementwise_handler(op, other_dim):
 
 
 if __name__ == '__main__':
-    test_binary_elementwise_handler()
+    test_binary_elementwise_handler_with_tensor()
+    test_binary_elementwise_handler_with_int()
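A minimal torch.fx sketch (simplified stand-ins for the test models, without the distributed setup) of why the test reads the op node at index 2 for BEOpModelWithNodeConst but index 1 for BEOpModelWithIntConst: x1.dim() traces to an extra call_method node ahead of the binary op.

import torch
import torch.fx as fx
import torch.nn as nn

class NodeConst(nn.Module):
    # scalar operand comes from another graph node (x1.dim())
    def forward(self, x1):
        return torch.add(x1, x1.dim())

class IntConst(nn.Module):
    # scalar operand is a plain Python int baked into the graph
    def forward(self, x1):
        return torch.add(x1, 2)

print([n.op for n in fx.symbolic_trace(NodeConst()).graph.nodes])
# ['placeholder', 'call_method', 'call_function', 'output'] -> add sits at index 2
print([n.op for n in fx.symbolic_trace(IntConst()).graph.nodes])
# ['placeholder', 'call_function', 'output']                -> add sits at index 1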
@@ -90,7 +90,8 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
     solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
-    target_node = list(graph.nodes)[node_index]
+    target_node = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies
+                   ][node_index]
     if node_type == 'normal':
         solution_len = len(strategies_constructor.leaf_strategies)
         solution = [0] * solution_len
@@ -112,7 +113,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
     gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
-        gm, solution, device_mesh)
+        gm, solution, device_mesh, strategies_constructor)
     gm = runtime_apply_pass(gm)
     gm.recompile()