mirror of https://github.com/hpcaitech/ColossalAI
[tensor] torch function return colotensor (#1229)
parent 5581170890
commit a98319f023
@@ -17,14 +17,12 @@ def register_elementwise_op(op):
         """
 
         output = op(input_tensor, *args, **kwargs)
 
         if isinstance(input_tensor, ColoTensor):
             if not isinstance(output, torch.Tensor):
                 raise NotImplementedError
-            return ColoTensor.from_torch_tensor(output,
-                                                spec=ColoTensorSpec(input_tensor.process_group,
-                                                                    dist_attr=input_tensor.dist_spec,
-                                                                    compute_attr=input_tensor.compute_spec))
+            return ColoTensor.from_torch_tensor(output,
+                                                spec=ColoTensorSpec(input_tensor.get_process_group(),
+                                                                    dist_attr=input_tensor.dist_spec))
 
 
 # Tensor op
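With this hunk, an elementwise torch op applied to a ColoTensor hands back another ColoTensor whose spec reuses the input's process group and dist spec, rather than a bare torch.Tensor. A minimal usage sketch, assuming torch.distributed has already been initialized (e.g. via colossalai.launch) and that ColoTensor, ColoTensorSpec, ProcessGroup and distspec are importable from colossalai.tensor (the import path is an assumption):

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, distspec

# Sketch only: requires an initialized distributed environment.
pg = ProcessGroup(tp_degree=2)    # hypothetical 2-way tensor-parallel group
t = torch.rand(4, 4)
x = ColoTensor(t, spec=ColoTensorSpec(pg, distspec.shard([0], [pg.tp_world_size()])))

y = torch.abs(x)                  # dispatched through register_elementwise_op
assert isinstance(y, ColoTensor)  # output is wrapped, not a plain torch.Tensor
# y's ColoTensorSpec was built from x.get_process_group() and x.dist_spec (see the hunk above)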
@@ -22,7 +22,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
         assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
         output = output + bias
 
-    pg = input_tensor.get_process_group()
+    pg = weight.get_process_group()
     output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, distspec.replicate()))
     return output
 
@@ -61,6 +61,7 @@ def colo_linear_imp(input_tensor: GeneralTensor,
     """
     assert isinstance(weight, ColoTensor)
     pg = weight.get_process_group()
+    assert pg
     input_tensor = convert_to_colo_tensor(input_tensor, pg)
     bias = convert_to_colo_tensor(bias, pg)
     # input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
@@ -70,7 +71,7 @@ def colo_linear_imp(input_tensor: GeneralTensor,
     if not weight.has_compute_spec():    # No Model Parallel Applied
         assert weight.is_replicate(), 'Invalid weight spec for native Linear op'
         assert bias is None or bias.is_replicate(), 'Invalid bias spec for native Linear op'
-        ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
+        ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias), spec=ColoTensorSpec(pg))
     elif weight.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
         if weight.is_shard_1dcol() and (bias is None or bias.is_replicate()):
             mode = 'row'
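After these two hunks the non-model-parallel Linear path also returns a ColoTensor bound to the weight's process group, and the op asserts that a process group is actually present before converting its inputs. A rough sketch of the replicated-weight path, with assumed import paths and an already-launched distributed environment:

import torch
import torch.nn.functional as F
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup

# Sketch only: ProcessGroup() with default arguments is an assumption.
pg = ProcessGroup()
weight = ColoTensor(torch.rand(8, 4), spec=ColoTensorSpec(pg))   # replicated weight, no compute spec
x = torch.rand(2, 4)

out = F.linear(x, weight)            # routed to colo_linear_imp via __torch_function__
assert isinstance(out, ColoTensor)   # now carries ColoTensorSpec(pg) built from the weight's process group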
@@ -35,7 +35,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
     elif input_tensor.has_compute_spec():    # Single Model Parallel Applied
         if input_tensor.is_shard_1dcol():
             output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
-            return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
+            return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)).to_replicate()
         else:
             raise NotImplementedError
     else:
@@ -11,12 +11,30 @@ from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
 from typing import Optional
 
 
-def _check_output(output):
-    if not isinstance(output, torch.Tensor):
-        raise RuntimeError
+def _convert_output(output, pg: ProcessGroup):
+    if type(output) == torch.Tensor:
+        return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
     elif isinstance(output, (list, tuple)):
-        output = type(output)(_check_output(o) for o in output)
-    return output
+        return type(output)(_convert_output(o, pg) for o in output)
+    else:
+        return output
 
 
+def _scan_for_pg_from_args(args, kwargs) -> ProcessGroup:
+    for elem in args:
+        if isinstance(elem, ColoTensor):
+            pg = elem.get_process_group()
+            return pg
+        elif isinstance(elem, (list, tuple)):
+            pg = _scan_for_pg_from_args(elem, {})
+            if pg is not None:
+                return pg
+        print(type(elem), elem, isinstance(elem, (list, tuple)))
+    for k, v in kwargs:
+        if isinstance(v, ColoTensor):
+            pg = v.get_process_group()
+            return pg
+    return None
+
+
 class ColoTensor(torch.Tensor):
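These two helpers carry the weight of the commit: _scan_for_pg_from_args looks for a ProcessGroup on any ColoTensor argument (recursing into lists and tuples), and _convert_output wraps tensor outputs, again recursing into list and tuple containers while preserving their type. A small illustration of the wrapping behaviour, assuming the helpers live in colossalai.tensor.colo_tensor (module path is an assumption) and that a distributed environment is already set up:

import torch
from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.tensor.colo_tensor import _convert_output   # internal helper, import path assumed

pg = ProcessGroup()                           # assumed default constructor
raw = (torch.rand(2, 2), [torch.rand(3)])     # nested containers of plain tensors

wrapped = _convert_output(raw, pg)
assert isinstance(wrapped, tuple)             # container type preserved via type(output)(...)
assert isinstance(wrapped[0], ColoTensor)     # tensors become ColoTensor carrying ColoTensorSpec(pg)
assert isinstance(wrapped[1][0], ColoTensor)  # recursion also converts the inner list's tensor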
@@ -108,6 +126,7 @@ class ColoTensor(torch.Tensor):
             dist_spec (_DistSpec): target dist spec.
         """
         assert isinstance(dist_spec, _DistSpec)
+        assert self.process_group
         self._convert_to_dist_spec(dist_spec)
 
     def set_tensor_spec(self, dist_spec, compute_spec):
@@ -136,12 +155,11 @@ class ColoTensor(torch.Tensor):
             if func in get_default_nowrap_functions():
                 return ret
             else:
-                # TODO(jiaruifang) its parallel Op's duty to convert output activations
-                return ret
-                # return _check_output(ret)
+                pg = _scan_for_pg_from_args(args, kwargs)
+                return _convert_output(ret, pg)
 
     def __repr__(self):
-        return f'ColoTensor: {super().__repr__()}'
+        return f'ColoTensor: {super().__repr__()}\n dist spec: {self.dist_spec}\n process group: {self.process_group}'
 
     def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
         """_convert_to_dist_spec
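This is the change the commit title refers to: instead of returning the raw result on the old TODO path, __torch_function__ now scans the arguments for a process group and converts the output, so generic torch ops on ColoTensor yield ColoTensor again; __repr__ additionally prints the dist spec and process group. An end-to-end sketch mirroring the _run_operand test at the bottom of this diff (import paths assumed, distributed environment already launched):

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup

pg = ProcessGroup()                  # assumed default constructor
t = ColoTensor(torch.rand(2, 2), spec=ColoTensorSpec(pg))

res = t + t                          # falls through the generic __torch_function__ path
assert isinstance(res, ColoTensor)   # previously a plain torch.Tensor came back
print(res)                           # repr now includes the dist spec and process group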
@@ -19,6 +19,10 @@ class PyTorchProcessGroupDict(metaclass=SingletonMeta):
         pg_key = (backend, rank_tuple)
 
         if pg_key not in self.dict:
+
+            self.logger = get_dist_logger('ProcessGroup')
+            self.logger.info(f'NCCL initialize TP group on {rank_list}', ranks=[0])
+
             self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
         return self.dict[pg_key]
 
@@ -92,10 +96,6 @@ class ProcessGroup:
             self._tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
             self._dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
 
-            self.logger = get_dist_logger('ProcessGroup')
-            self.logger.info(
-                f'{self._rank} NCCL initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
-
         self._has_cpu_groups = False
         self._cpu_dp_process_group = None
         self._cpu_tp_process_group = None
@@ -113,6 +113,7 @@ def run_1d_hybrid_tp(model_name):
 
         torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
         torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
+
         # Bcast rank0 data to all processes
         if criterion:
             output = model(data)
@@ -39,7 +39,7 @@ def check_spec_eq(tensor, other):
     assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
     for k in dir(tensor.dist_spec):
         if not k.startswith('__'):
-            assert hasattr(other.dist_spec, k)
+            assert hasattr(other.dist_spec, k), f"{k}"
             assert getattr(tensor.dist_spec, k) == getattr(other.dist_spec, k)
 
 
@@ -48,6 +48,7 @@ def check_element_wise_ops():
     pg = ProcessGroup(tp_degree=world_size)
     t = torch.rand(2, 2)
     x = ColoTensor(t, spec=ColoTensorSpec(pg, distspec.shard([0], [pg.tp_world_size()])))
+
     check_spec_eq(x, x.cuda())
     assert torch.equal(x.cuda(), t.cuda())
     check_spec_eq(x, torch.abs(x))
@@ -49,6 +49,8 @@ def _run_operand():
 
     t_ref_res = t_ref + t_ref
     t_res = t + t
+
+    assert isinstance(t_res, ColoTensor)
     assert torch.allclose(t_ref_res, t_res)
 
 