mirror of https://github.com/hpcaitech/ColossalAI
[Tensor] rename some APIs in TensorSpec and Polish view unittest (#1176)
parent dd0420909f
commit 0dd4e2bbfb
@@ -72,10 +72,10 @@ def colo_addmm(input_tensor: GeneralTensor,
         assert input_tensor.tensor_spec.is_replicate(), 'Invalid input spec for native addmm op'
         ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
     elif mat2.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if mat2.tensor_spec.is_1D_row() and input_tensor.tensor_spec.is_replicate():
+        if mat2.tensor_spec.is_shard_1drow() and input_tensor.tensor_spec.is_replicate():
             mode = 'row'
-        elif mat2.tensor_spec.is_1D_col() and (input_tensor.tensor_spec.is_1D_col()
-                                               or input_tensor.tensor_spec.is_1D_row()):
+        elif mat2.tensor_spec.is_shard_1dcol() and (input_tensor.tensor_spec.is_shard_1dcol()
+                                                    or input_tensor.tensor_spec.is_shard_1drow()):
             mode = 'col'
         else:
             raise NotImplementedError
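The dispatch above hinges on how mat2 is sharded across the tensor-parallel group: a 1-D row shard of mat2 splits the contraction dimension (so partial products must be summed and the other operands stay replicated), while a 1-D column shard splits the output columns. Below is a single-process sketch in plain PyTorch, not ColossalAI code; the two-way chunking stands in for a hypothetical 2-rank group.

# Single-process sketch of why the dispatcher picks 'row' vs 'col' mode for
# torch.addmm, simulating a 2-way 1-D shard of mat2.
import torch

torch.manual_seed(0)
input_tensor = torch.randn(4, 6)          # bias-like operand, replicated
mat1 = torch.randn(4, 8)
mat2 = torch.randn(8, 6)
expected = torch.addmm(input_tensor, mat1, mat2)

# 'row' mode: mat2 sharded along dim 0 (is_shard_1drow). Each "rank" holds a row
# block of mat2 and the matching column block of mat1; the partial results are
# summed (the all-reduce in the real TP1D kernel) before the bias is added.
partials = [m1 @ m2 for m1, m2 in zip(mat1.chunk(2, dim=1), mat2.chunk(2, dim=0))]
row_mode = input_tensor + sum(partials)

# 'col' mode: mat2 sharded along the last dim (is_shard_1dcol). Each "rank"
# computes a column slice of the output; the slices are concatenated.
col_mode = torch.cat([torch.addmm(i, mat1, m2)
                      for i, m2 in zip(input_tensor.chunk(2, dim=-1), mat2.chunk(2, dim=-1))],
                     dim=-1)

assert torch.allclose(expected, row_mode, atol=1e-5)
assert torch.allclose(expected, col_mode, atol=1e-5)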
@@ -32,6 +32,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
+
     compute_spec = weight.tensor_spec.compute_spec
 
     if compute_spec.output_replicate:
         return output.to_replicate()
     else:
@@ -125,9 +126,9 @@ def colo_embedding(input_tensor: GeneralTensor,
                                  scale_grad_by_freq=scale_grad_by_freq,
                                  sparse=sparse))
     elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if weight.tensor_spec.is_1D_row():
+        if weight.tensor_spec.is_shard_1drow():
             mode = 'row'
-        elif weight.tensor_spec.is_1D_col():
+        elif weight.tensor_spec.is_shard_1dcol():
             mode = 'col'
         else:
             raise NotImplementedError
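For the embedding op the same predicates describe which axis of the weight is split: a row shard (is_shard_1drow) splits the vocabulary, a column shard (is_shard_1dcol) splits the embedding dimension. The column case is the simpler one, since the per-shard lookups just concatenate; a single-process sketch in plain PyTorch (not ColossalAI code) follows.

# Column-shard intuition for F.embedding: each shard holds part of the embedding
# dim, and the outputs concatenate back to the full embedding.
import torch
import torch.nn.functional as F

weight = torch.randn(10, 8)                  # (vocab_size, embed_dim)
ids = torch.tensor([1, 4, 7])
expected = F.embedding(ids, weight)

col_mode = torch.cat([F.embedding(ids, w) for w in weight.chunk(2, dim=-1)], dim=-1)
assert torch.allclose(expected, col_mode)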
@@ -104,7 +104,7 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
                                      include_last_offset=include_last_offset,
                                      padding_idx=padding_idx))
     elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if weight.tensor_spec.is_1D_col():
+        if weight.tensor_spec.is_shard_1dcol():
             tp_mode = 'col'
         else:
             raise NotImplementedError
@@ -71,10 +71,10 @@ def colo_linear_imp(input_tensor: GeneralTensor,
         assert bias is None or bias.tensor_spec.is_replicate(), 'Invalid bias spec for native Linear op'
         ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
     elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if weight.tensor_spec.is_1D_col() and (bias is None or bias.tensor_spec.is_replicate()):
+        if weight.tensor_spec.is_shard_1dcol() and (bias is None or bias.tensor_spec.is_replicate()):
             mode = 'row'
-        elif weight.tensor_spec.is_1D_row() and (bias is None or bias.tensor_spec.is_1D_row()
-                                                 or bias.tensor_spec.is_1D_col()):
+        elif weight.tensor_spec.is_shard_1drow() and (bias is None or bias.tensor_spec.is_shard_1drow()
+                                                      or bias.tensor_spec.is_shard_1dcol()):
             mode = 'col'
         else:
             raise RuntimeError(f"the weight or bias tensor spec is not valid, weight {weight.tensor_spec}, bias {bias}")
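Note the flipped naming in colo_linear_imp: F.linear stores weight as (out_features, in_features), so a column shard of the weight (is_shard_1dcol, last dim) splits the input features and is handled as row-parallel ('row' mode, bias replicated), while a row shard (dim 0) splits the output features ('col' mode, bias sharded to match). A single-process sketch in plain PyTorch, with a 2-way chunk standing in for the process group:

# 'row' vs 'col' mode for F.linear, simulated on one process.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8)
weight = torch.randn(6, 8)          # (out_features, in_features)
bias = torch.randn(6)
expected = F.linear(x, weight, bias)

# 'row' mode: weight split along dim -1 (input features); partial outputs are summed.
partials = [F.linear(xi, wi) for xi, wi in zip(x.chunk(2, dim=-1), weight.chunk(2, dim=-1))]
row_mode = sum(partials) + bias

# 'col' mode: weight split along dim 0 (output features); output slices are concatenated.
col_mode = torch.cat([F.linear(x, wi, bi)
                      for wi, bi in zip(weight.chunk(2, dim=0), bias.chunk(2, dim=0))], dim=-1)

assert torch.allclose(expected, row_mode, atol=1e-5)
assert torch.allclose(expected, col_mode, atol=1e-5)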
@@ -29,7 +29,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                                   label_smoothing=label_smoothing)
         return ColoTensor.from_torch_tensor(output)
     elif input_tensor.has_compute_spec():    # Single Model Parallel Applied
-        if input_tensor.tensor_spec.is_1D_col():
+        if input_tensor.tensor_spec.is_shard_1dcol():
             output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
             return ColoTensor.from_torch_tensor(output)
         else:
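The column-shard check here is what routes vocab-sharded logits to VocabParallelCrossEntropyLoss1D: the softmax denominator spans the full class dimension, so a per-shard cross entropy is wrong without a cross-rank reduction. A single-process sketch (plain PyTorch, illustrative only):

# Per-shard softmax differs from softmax over the full vocabulary.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)                          # (batch, num_classes)
shard = logits[:, :5]                                # classes 0..4, as one "rank" would hold

local_probs = F.softmax(shard, dim=-1)               # normalized over half the vocab only
true_probs = F.softmax(logits, dim=-1)[:, :5]        # normalized over the full vocab
assert not torch.allclose(local_probs, true_probs)   # local normalization is not enough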
@@ -116,6 +116,7 @@ class Chunk:
         if self.is_src_rank:
             self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
             tensor_state = TensorState.HOLD
+            assert type(self._payload) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
             tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
         else:
             tensor.storage().resize_(0)
@@ -131,6 +132,7 @@ class Chunk:
             self._update_tensors_state(TensorState.FREE)
 
     def _update_tensors_ptr(self) -> None:
+        assert type(self._payload) == torch.Tensor
         for tensor, tensor_info in self.tensors_info.items():
             tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
 
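Both new asserts guard the same pattern, presumably against a payload that has been released or replaced by something other than a plain torch.Tensor: tensors are copied into a flat buffer and their .data is re-pointed at a view of that buffer, which only makes sense on a real tensor. A hypothetical, stripped-down stand-in (MiniChunk and append are illustrative names, not the ColossalAI API):

# Minimal flat-buffer chunk: copy a tensor in, then alias it to the payload slice.
import torch

class MiniChunk:
    def __init__(self, capacity: int):
        self._payload = torch.zeros(capacity)
        self.utilized_size = 0

    def append(self, tensor: torch.Tensor) -> None:
        new_utilized_size = self.utilized_size + tensor.numel()
        # Mirrors the new guard: the view below requires a real torch.Tensor payload.
        assert type(self._payload) == torch.Tensor, "payload must be a torch tensor"
        self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
        tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
        self.utilized_size = new_utilized_size

chunk = MiniChunk(capacity=32)
t = torch.randn(2, 3)
chunk.append(t)
assert t.data_ptr() == chunk._payload[:6].data_ptr()   # t now aliases the chunk payload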
@@ -228,7 +230,7 @@ class Chunk:
             data_slice (torch.Tensor): the tensor to be copied to the chunk
         """
         tensor_info = self.tensors_info[tensor]
-        self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
+        self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
         tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
 
     @property
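The swap from data_slice.view(-1) to data_slice.flatten() matters when the incoming slice is not contiguous: view(-1) raises in that case, while flatten() falls back to a copy. A short, self-contained demonstration:

# .view(-1) requires a contiguous tensor; .flatten() does not.
import torch

data_slice = torch.randn(4, 6).t()        # transposed, hence non-contiguous
assert not data_slice.is_contiguous()

try:
    data_slice.view(-1)                   # what the old code did
except RuntimeError as e:
    print("view(-1) failed:", e)

flat = data_slice.flatten()               # what the new code does
assert flat.shape == (24,)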
@@ -54,5 +54,5 @@ def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int]
     assert process_group is not None
     assert isinstance(dims, list) and isinstance(num_partitions, list)
     assert len(dims) == len(num_partitions)
-    assert prod(num_partitions) == process_group.size()
+    assert prod(num_partitions) == process_group.size(), f"{num_partitions} {process_group.size()}"
     return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions))
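The added message only makes the existing invariant visible when it fails: the partition grid must cover the process group exactly, i.e. the product of num_partitions equals the group size. A toy check independent of ColossalAI (check_shard_args is an illustrative name):

# The invariant behind the improved assertion message.
from math import prod

def check_shard_args(dims, num_partitions, group_size):
    assert isinstance(dims, list) and isinstance(num_partitions, list)
    assert len(dims) == len(num_partitions)
    assert prod(num_partitions) == group_size, f"{num_partitions} {group_size}"

check_shard_args(dims=[0], num_partitions=[4], group_size=4)      # OK
try:
    check_shard_args(dims=[0], num_partitions=[2], group_size=4)  # 2 != 4
except AssertionError as e:
    print("shard spec rejected:", e)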
@@ -32,11 +32,11 @@ class TensorSpec(object):
                 and self.dist_spec.num_partitions[0] == 1) \
             or (self.dist_spec.process_group.size() == 1)
 
-    def is_1D_col(self):
+    def is_shard_1dcol(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
 
-    def is_1D_row(self):
+    def is_shard_1drow(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
 
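The renamed predicates only inspect the dist spec: SHARD placement over exactly one dimension, 0 for a row shard and -1 for a column shard. A minimal mock, not the real TensorSpec (the class and field names below are illustrative; the 'r'/'s' enum values mirror what the view test asserts):

# Stand-alone mock of the two renamed predicates.
from dataclasses import dataclass
from enum import Enum
from typing import Tuple

class DistPlacementPattern(Enum):
    REPLICATE = 'r'
    SHARD = 's'

@dataclass
class MockDistSpec:
    placement: DistPlacementPattern
    dims: Tuple[int, ...] = ()

def is_shard_1drow(spec: MockDistSpec) -> bool:
    return spec.placement == DistPlacementPattern.SHARD and spec.dims == (0,)

def is_shard_1dcol(spec: MockDistSpec) -> bool:
    return spec.placement == DistPlacementPattern.SHARD and spec.dims == (-1,)

assert is_shard_1drow(MockDistSpec(DistPlacementPattern.SHARD, (0,)))
assert is_shard_1dcol(MockDistSpec(DistPlacementPattern.SHARD, (-1,)))
assert not is_shard_1dcol(MockDistSpec(DistPlacementPattern.REPLICATE))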
@@ -63,13 +63,19 @@ def test_operand():
 def _run_view(world_size):
     t_ref = torch.randn(4, 5)
     t = ColoTensor.from_torch_tensor(
-        t_ref, TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[2])))
+        t_ref,
+        TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0],
+                                  num_partitions=[world_size])))
 
     assert t.size()[0] == 4 * world_size
     assert t.size(1) == 5
     assert t.size() == torch.Size([4 * world_size, 5])
 
+    t.view_base(4 * 5)
+    assert t.tensor_spec.dist_spec.placement.value == 's'
+
     t = t.view(4 * 5 * world_size)
+    assert t.tensor_spec.dist_spec.placement.value == 'r'
     assert t.shape == torch.Size([4 * 5 * world_size])
 
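The polished test also fixes the hard-coded num_partitions=[2] to use world_size, and it now distinguishes the two code paths: t.view_base(4 * 5) appears to reshape only the local (still sharded, placement 's') payload, while t.view(4 * 5 * world_size) asks for every element and so ends up replicated (placement 'r'). The arithmetic the asserts rely on, as a trivial sketch under the assumption above:

# Element counts behind the updated asserts, for a (4, 5) local shard per rank.
local_rows, cols, world_size = 4, 5, 2          # world_size=2 is one of the pytest params
local_numel = local_rows * cols                 # what t.view_base(4 * 5) reshapes
global_shape = (local_rows * world_size, cols)  # what t.size() reports
global_numel = global_shape[0] * global_shape[1]
assert global_numel == 4 * 5 * world_size       # the size t.view(...) asks for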
@@ -100,11 +106,10 @@ def run_dist_tests(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2])
 @rerun_if_address_is_in_use()
-def _test_dist_cases(world_size):
+def test_dist_cases(world_size):
     run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    # _test_dist_init(4)
-    _test_dist_cases(2)
+    test_dist_cases(2)
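Dropping the leading underscore is what makes the case run under CI: by default pytest only collects functions whose names start with test (the python_functions setting), so _test_dist_cases was silently skipped. A tiny illustration (hypothetical file, e.g. test_naming_demo.py):

# Only the second function is collected by a plain `pytest` run.
def _test_skipped_by_collection():
    assert False        # never executed by default collection

def test_collected_and_run():
    assert 1 + 1 == 2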