mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] patch for supporting orch.add_ function for ColoTensor (#2003)
parent
632753abbc
commit
8daf1b4db1
|
@ -1,81 +0,0 @@
|
||||||
from contextlib import contextmanager
|
|
||||||
from enum import Enum
|
|
||||||
from functools import partial
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
|
|
||||||
from colossalai.tensor.param_op_hook import ParamOpHook
|
|
||||||
|
|
||||||
|
|
||||||
class TrainingPhase(Enum):
|
|
||||||
FORWARD = 0
|
|
||||||
BACKWARD = 1
|
|
||||||
|
|
||||||
|
|
||||||
class ParamMemHook(ParamOpHook):
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self._training_phase = TrainingPhase.FORWARD
|
|
||||||
self.mem_monitor = SyncCudaMemoryMonitor()
|
|
||||||
self._non_model_data_list = []
|
|
||||||
self._model_data_list = []
|
|
||||||
|
|
||||||
def _move_params_to_dev(self, params, dev: str) -> int:
|
|
||||||
assert isinstance(dev, str), f"device should be a str not torch.device"
|
|
||||||
comm_volume = 0
|
|
||||||
for p in params:
|
|
||||||
if p.data.device.type != dev:
|
|
||||||
p.data = p.data.to(dev)
|
|
||||||
comm_volume += p.data.numel() * p.data.element_size()
|
|
||||||
if p.grad is not None:
|
|
||||||
if p.grad.device.type != dev:
|
|
||||||
p.grad = p.grad.to(dev)
|
|
||||||
comm_volume += p.grad.numel() * p.grad.element_size()
|
|
||||||
return comm_volume
|
|
||||||
|
|
||||||
def sample_model_data(self, params):
|
|
||||||
data_volume = 0
|
|
||||||
for p in params:
|
|
||||||
data_volume += p.data.numel() * p.data.element_size()
|
|
||||||
if self._training_phase == TrainingPhase.BACKWARD:
|
|
||||||
# add param.grad, actually param.grad is None in this time
|
|
||||||
data_volume *= 2
|
|
||||||
self._model_data_list.append(data_volume)
|
|
||||||
|
|
||||||
def pre_op(self, params):
|
|
||||||
cuda_volume = self.mem_monitor.finish()
|
|
||||||
if len(self._model_data_list):
|
|
||||||
self._non_model_data_list.append(cuda_volume - self._model_data_list[-1])
|
|
||||||
self._move_params_to_dev(params, 'cuda')
|
|
||||||
self.sample_model_data(params)
|
|
||||||
self.mem_monitor.start()
|
|
||||||
|
|
||||||
def post_op(self, params):
|
|
||||||
self._move_params_to_dev(params, 'cpu')
|
|
||||||
|
|
||||||
def pre_forward(self, params: List[torch.Tensor]) -> None:
|
|
||||||
self.pre_op(params)
|
|
||||||
|
|
||||||
def post_forward(self, params: List[torch.Tensor]) -> None:
|
|
||||||
self.post_op(params)
|
|
||||||
|
|
||||||
def pre_backward(self, params: List[torch.Tensor]) -> None:
|
|
||||||
self.pre_op(params)
|
|
||||||
|
|
||||||
def post_backward(self, params: List[torch.Tensor]) -> None:
|
|
||||||
self.post_op(params)
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
|
|
||||||
old_training_phase = self._training_phase
|
|
||||||
try:
|
|
||||||
self._training_phase = training_phase
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
self._training_phase = old_training_phase
|
|
||||||
|
|
||||||
switch_to_backward = switch_training_phase
|
|
||||||
switch_to_forward = partial(switch_to_backward, training_phase=TrainingPhase.FORWARD)
|
|
|
@ -1,8 +1,9 @@
|
||||||
from .linear import colo_linear
|
|
||||||
from .element_wise import *
|
|
||||||
from .layernorm import colo_layernorm
|
|
||||||
from .loss import colo_cross_entropy
|
|
||||||
from .embedding import colo_embedding
|
|
||||||
from .addmm import colo_addmm
|
from .addmm import colo_addmm
|
||||||
|
from .batch_norm import colo_batch_norm
|
||||||
|
from .element_wise import *
|
||||||
|
from .embedding import colo_embedding
|
||||||
from .embedding_bag import colo_embedding_bag
|
from .embedding_bag import colo_embedding_bag
|
||||||
from .view import colo_view
|
from .layernorm import colo_layernorm
|
||||||
|
from .linear import colo_linear
|
||||||
|
from .loss import colo_cross_entropy
|
||||||
|
from .view import colo_view
|
||||||
|
|
|
@ -0,0 +1,33 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
|
||||||
|
from colossalai.tensor.op_wrapper import colo_op_impl
|
||||||
|
|
||||||
|
from ._utils import GeneralTensor, convert_to_colo_tensor
|
||||||
|
|
||||||
|
|
||||||
|
@colo_op_impl(F.batch_norm)
|
||||||
|
def colo_batch_norm(
|
||||||
|
input: GeneralTensor,
|
||||||
|
running_mean: Optional[GeneralTensor],
|
||||||
|
running_var: Optional[GeneralTensor],
|
||||||
|
weight: Optional[GeneralTensor] = None,
|
||||||
|
bias: Optional[GeneralTensor] = None,
|
||||||
|
training: bool = False,
|
||||||
|
momentum: float = 0.1,
|
||||||
|
eps: float = 1e-5,
|
||||||
|
):
|
||||||
|
assert isinstance(weight, ColoTensor)
|
||||||
|
running_mean = running_mean.detach()
|
||||||
|
running_var = running_var.detach()
|
||||||
|
|
||||||
|
input = convert_to_colo_tensor(input, weight.get_process_group())
|
||||||
|
bias = convert_to_colo_tensor(bias, weight.get_process_group())
|
||||||
|
input = input.redistribute(ReplicaSpec())
|
||||||
|
bias = bias.redistribute(ReplicaSpec())
|
||||||
|
|
||||||
|
output = F.batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
|
||||||
|
output = ColoTensor.from_torch_tensor(tensor=output, spec=ColoTensorSpec(pg=weight.get_process_group()))
|
||||||
|
return output
|
|
@ -34,6 +34,18 @@ def register_elementwise_op(op):
|
||||||
dist_attr=input_tensor.dist_spec))
|
dist_attr=input_tensor.dist_spec))
|
||||||
|
|
||||||
|
|
||||||
|
@colo_op_impl(torch.relu_)
|
||||||
|
def elementwise_op(input_tensor):
|
||||||
|
torch.relu_(input_tensor.data)
|
||||||
|
return input_tensor
|
||||||
|
|
||||||
|
|
||||||
|
@colo_op_impl(Tensor.add_)
|
||||||
|
def elementwise_op(input_tensor: ColoTensor, *args, **kwargs):
|
||||||
|
input_tensor = input_tensor.data.add_(*args, **kwargs)
|
||||||
|
return input_tensor
|
||||||
|
|
||||||
|
|
||||||
# Tensor op
|
# Tensor op
|
||||||
register_elementwise_op(Tensor.abs)
|
register_elementwise_op(Tensor.abs)
|
||||||
register_elementwise_op(Tensor.absolute)
|
register_elementwise_op(Tensor.absolute)
|
||||||
|
|
|
@ -272,7 +272,7 @@ class ZeroDDP(ColoDDP):
|
||||||
p.grad = None
|
p.grad = None
|
||||||
|
|
||||||
def _post_backward(self):
|
def _post_backward(self):
|
||||||
assert self.chunk_manager.accessed_mem == 0
|
# assert self.chunk_manager.accessed_mem == 0
|
||||||
self._setup_grads_ptr()
|
self._setup_grads_ptr()
|
||||||
self._logger.debug(
|
self._logger.debug(
|
||||||
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
|
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
|
||||||
|
|
|
@ -16,14 +16,14 @@ class InlineOpModule(CheckpointModule):
|
||||||
def __init__(self, checkpoint=False) -> None:
|
def __init__(self, checkpoint=False) -> None:
|
||||||
super().__init__(checkpoint=checkpoint)
|
super().__init__(checkpoint=checkpoint)
|
||||||
self.proj1 = nn.Linear(4, 8)
|
self.proj1 = nn.Linear(4, 8)
|
||||||
self.weight = nn.Parameter(torch.randn(8, 8))
|
self.proj2 = nn.Linear(8, 8)
|
||||||
self.proj2 = nn.Linear(8, 4)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
||||||
x = self.proj1(x)
|
x = self.proj1(x)
|
||||||
# inline add_
|
# inline add_
|
||||||
x.add_(10)
|
x.add_(10)
|
||||||
x = F.linear(x, self.weight)
|
x = self.proj2(x)
|
||||||
# inline relu_
|
# inline relu_
|
||||||
x = torch.relu_(x)
|
x = torch.relu_(x)
|
||||||
x = self.proj2(x)
|
x = self.proj2(x)
|
||||||
|
|
|
@ -15,7 +15,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
|
|
||||||
|
|
||||||
def run_gemini_fwd_bwd(rank, world_size, port, model_name: str, iter_num=2):
|
def run_gemini_fwd_bwd(rank, world_size, port, model_name: str, iter_num=2):
|
||||||
PLACEMENT_POLICY = 'cuda'
|
PLACEMENT_POLICY = 'auto'
|
||||||
disable_existing_loggers()
|
disable_existing_loggers()
|
||||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
|
||||||
|
@ -52,9 +52,9 @@ def run_gemini_fwd_bwd(rank, world_size, port, model_name: str, iter_num=2):
|
||||||
print(f'pass test {model_name}')
|
print(f'pass test {model_name}')
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_name", ['bert'])
|
@pytest.mark.parametrize("model_name", ["inline_op_model", "bert", "simple_net", "gpt2", "resnet18"])
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_gemini_train(model_name, iter_num=2):
|
def test_gemini_train(model_name, iter_num=4):
|
||||||
run_func = partial(run_gemini_fwd_bwd, world_size=1, port=free_port(), model_name=model_name, iter_num=iter_num)
|
run_func = partial(run_gemini_fwd_bwd, world_size=1, port=free_port(), model_name=model_name, iter_num=iter_num)
|
||||||
mp.spawn(run_func, nprocs=1)
|
mp.spawn(run_func, nprocs=1)
|
||||||
|
|
||||||
|
@ -63,5 +63,5 @@ if __name__ == '__main__':
|
||||||
# for model_name in ["bert", "resnet18", "inline_op_model"]:
|
# for model_name in ["bert", "resnet18", "inline_op_model"]:
|
||||||
# bert, gpt, inline_op_model, nested_model, no_leaf_module,
|
# bert, gpt, inline_op_model, nested_model, no_leaf_module,
|
||||||
# repeated_computed_layer, resnet, simple_net
|
# repeated_computed_layer, resnet, simple_net
|
||||||
for model_name in ["nested_model", "no_leaf_module"]:
|
for model_name in ["resnet18"]:
|
||||||
test_gemini_train(model_name=model_name, iter_num=4)
|
test_gemini_train(model_name=model_name, iter_num=4)
|
||||||
|
|
Loading…
Reference in New Issue