mirror of https://github.com/hpcaitech/ColossalAI
[zero] update zero context init with the updated test utils (#327)
parent 6268446b81
commit 11bddb6e55
@@ -1,4 +1,3 @@
-from re import S
 from colossalai.context.parallel_mode import ParallelMode
 import torch
 from . import BaseOpHook
@@ -7,7 +6,7 @@ from colossalai.registry import OPHOOKS
 from colossalai.logging import get_dist_logger
 from time import sleep, time
 import pickle
-from typing import Union, Optional
+from typing import Optional
 from colossalai.core import global_context as gpc


@@ -19,12 +18,13 @@ def get_cuda_memory_used(device: Optional[torch.device]) -> int:
     """
     ret: int = torch.cuda.memory_allocated(device)
     # get the peak memory to report correct data, so reset the counter for the next call
     if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
         torch.cuda.reset_peak_memory_stats(device)
     return ret


 class AsyncMemoryMonitor:

     def __init__(self, power=10):
         """
         An Async Mem Monitor runing during computing.
@@ -81,7 +81,7 @@ class AsyncMemoryMonitor:
     def save(self, filename):
         with open(filename, "wb") as f:
             pickle.dump(self.state_dict(), f)

     def clear(self):
         self.mem_stats.clear()
         self.time_stamps.clear()
@@ -92,7 +92,7 @@ class MemTracerOpHook(BaseOpHook):
     '''
     Collect GPU memory usage information

     Args:
         warmup (int): This parameter indicates how many iterations to truncate
             before profiling, e.g. set to 5 and the data will start from 6-th iteration
         refreshrate (int): This parameter decides the frequency of write file.
@@ -106,6 +106,7 @@ class MemTracerOpHook(BaseOpHook):
         _data_prefix (string): the prefix of the stats data file
         _rank (int): the rank of current node
     '''

     def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
         super().__init__()
         self.async_mem_monitor = AsyncMemoryMonitor()
@@ -128,7 +129,7 @@ class MemTracerOpHook(BaseOpHook):
     @property
     def refreshrate(self) -> int:
         return self._refreshrate

     @property
     def warmup(self) -> int:
         return self._warmup
@@ -178,8 +179,7 @@ class MemTracerOpHook(BaseOpHook):
             # every `refreshrate` times, refresh the file
             if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
                 # output file info
-                self._logger.info(
-                    f'dump a memory statistics as pickle to {self._dataprefix}-{self._rank}.pkl')
+                self._logger.info(f'dump a memory statistics as pickle to {self._dataprefix}-{self._rank}.pkl')
                 self.save_results()
                 self._count += 1
                 self._logger.debug(f'data file has been refreshed {self._count} times')
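
Note (not part of the diff): get_cuda_memory_used above resets the CUDA peak-memory counter on every call so that successive reports stay meaningful. A minimal standalone sketch of that bookkeeping pattern, using only public torch.cuda calls (the helper name is mine, not from the codebase):

import torch

def peak_cuda_memory_and_reset(device=None) -> int:
    # High-water mark of allocated CUDA memory since the last reset.
    peak = torch.cuda.max_memory_allocated(device)
    # Reset the counter so the next call reports only the next interval's peak.
    if hasattr(torch.cuda, "reset_peak_memory_stats"):    # PyTorch 1.4+
        torch.cuda.reset_peak_memory_stats(device)
    return peak
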
@@ -82,25 +82,31 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     3. Shard the param and grad according to flags.
     """

-    def __init__(
-        self,
-        convert_fp16: bool,
-        convert_cuda: bool,
-        shard_strategy: BaseShardStrategy,
-        shard_param: bool = False,
-        shard_grad: bool = False,
-    ):
+    def __init__(self,
+                 convert_fp16: bool,
+                 convert_cuda: bool,
+                 shard_strategy: BaseShardStrategy,
+                 shard_param: bool = False,
+                 shard_grad: bool = False,
+                 rm_torch_payload_on_the_fly=False):
         super().__init__()
         self.convert_fp16 = convert_fp16
         self.convert_cuda = convert_cuda
         self.shard_param = shard_param
         self.shard_grad = shard_grad
         self.shard_strategy = shard_strategy
+        self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
+        self.initialized_param_list = []

     def _post_context_exec(self):
         """The callback function when the context exits.
         """
-        pass
+        if not self.rm_torch_payload_on_the_fly:
+            for param in self.initialized_param_list:
+                assert hasattr(param, 'ca_attr')
+                param.ca_attr.remove_torch_payload()
+
+            del self.initialized_param_list

     def _post_init_method(self, module):
         r"""The function to call at the end of the constructor of each nn.Module.
@@ -121,7 +127,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if param.grad is not None:
                 param.grad = param.grad.to(torch.half).to(target_device)

-            param.ca_attr = ShardedParamV2(param)
+            param.ca_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
+
+            self.initialized_param_list.append(param)
+
             if self.shard_param:
                 self.shard_strategy.shard(tensor_list=[param.ca_attr._data_sharded_tensor])
             if param.ca_attr.grad and self.shard_grad:
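
Note (not part of the diff): a hedged usage sketch of the updated context, mirroring the pattern in the test further down. It assumes colossalai.launch() has already set up the process group, and shows the new rm_torch_payload_on_the_fly flag at its default.

import torch.nn as nn

from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy

# Sketch only: with rm_torch_payload_on_the_fly=False (the new default), each
# parameter keeps its torch payload during construction so weight init still
# works, and _post_context_exec releases all payloads in one pass on exit.
with ZeroInitContext(convert_fp16=True,
                     convert_cuda=True,
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     rm_torch_payload_on_the_fly=False):
    model = nn.Linear(5, 5)    # any nn.Module built inside the context is intercepted
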
@@ -7,6 +7,11 @@ from typing import List, Optional
 class BaseShardStrategy(ABC):

     def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
+        """Abstract Shard Strategy. Use to shard a tensors on multiple GPUs.
+
+        Args:
+            process_group (Optional[dist.ProcessGroup], optional): the process group. Defaults to None.
+        """
         self.process_group = process_group
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
@@ -14,14 +19,8 @@ class BaseShardStrategy(ABC):

     @abstractmethod
     def shard(self, tensor_list: List[ShardedTensor]):
-        r"""
-        sharded the memory of tensor on multiple processes.
-        """
         pass

     @abstractmethod
     def gather(self, tensor_list: List[ShardedTensor]):
-        r"""
-        duplicate tensor payload on each processes.
-        """
         pass
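
Note (not part of the diff): a minimal concrete subclass, written here only to illustrate the abstract interface; the import path is assumed and no real sharding is performed.

from typing import List

from colossalai.zero.shard_utils.base_shard_strategy import BaseShardStrategy    # path assumed


class NoOpShardStrategy(BaseShardStrategy):
    """Illustrative placeholder: keeps every tensor whole on every rank."""

    def shard(self, tensor_list: List):
        # A real strategy splits each ShardedTensor's payload across
        # self.world_size ranks, keeping only the slice for self.local_rank.
        pass

    def gather(self, tensor_list: List):
        # A real strategy all-gathers the slices so each rank holds the
        # full payload again.
        pass
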
@@ -10,7 +10,10 @@ from typing import Union, Tuple, Optional

 class ShardedParamV2(object):

-    def __init__(self, param: torch.nn.Parameter, process_group: Optional[dist.ProcessGroup] = None) -> None:
+    def __init__(self,
+                 param: torch.nn.Parameter,
+                 process_group: Optional[dist.ProcessGroup] = None,
+                 rm_torch_payload=False) -> None:
         self._data_sharded_tensor = ShardedTensor(param.data, process_group)
         if param.requires_grad and param.grad is not None:
             self._grad_sharded_tensor = ShardedTensor(param.grad, process_group)
@@ -19,7 +22,16 @@
             self._grad_sharded_tensor = None

         # make sure the shared param is the only owner of payload
-        param.data = torch.empty([], dtype=param.dtype, device=param.device)
+        # The param.data maybe used to init the other part of the model.
+        # For example: File "resnet.py", line 190, in __init__
+        # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+        # So we can not empty the .data at this time
+        self.param = param
+        if rm_torch_payload:
+            self.remove_torch_payload()
+
+    def remove_torch_payload(self):
+        self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)

     @property
     def data(self):
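
Note (not part of the diff): a hedged sketch of the new deferred payload removal, consistent with the assertion added to the shard-param test at the bottom of this diff; it assumes a process group is already initialized and the import path is a guess.

import torch

from colossalai.zero.sharded_param import ShardedParamV2    # import path assumed

param = torch.nn.Parameter(torch.randn(2, 3))

# rm_torch_payload defaults to False, so param.data stays intact and can still
# feed weight initializers such as nn.init.kaiming_normal_.
sparam = ShardedParamV2(param, process_group=None, rm_torch_payload=False)
assert param.data.numel() == 6

# Dropping the payload explicitly replaces param.data with a dimensionless
# empty tensor, hence numel() == 1, as the updated test asserts.
sparam.remove_torch_payload()
assert param.data.numel() == 1
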
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from colossalai.nn import CheckpointModule
 from .utils import DummyDataGenerator
 from .registry import non_distributed_component_funcs

@@ -15,10 +16,10 @@ class SubNet(nn.Module):
         return F.linear(x, weight, self.bias)


-class NestedNet(nn.Module):
+class NestedNet(CheckpointModule):

-    def __init__(self) -> None:
-        super().__init__()
+    def __init__(self, checkpoint=False) -> None:
+        super().__init__(checkpoint)
         self.fc1 = nn.Linear(5, 5)
         self.sub_fc = SubNet(5)
         self.fc2 = nn.Linear(5, 2)
@@ -41,9 +42,15 @@ class DummyDataLoader(DummyDataGenerator):

 @non_distributed_component_funcs.register(name='nested_model')
 def get_training_components():
-    model = NestedNet()
+
+    def model_builder(checkpoint):
+        return NestedNet(checkpoint)
+
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()
-    optim = torch.optim.Adam(model.parameters(), lr=0.001)
+
+    def optim_builder(model):
+        return torch.optim.Adam(model.parameters(), lr=0.001)
+
     criterion = torch.nn.CrossEntropyLoss()
-    return model, trainloader, testloader, optim, criterion
+    return model_builder, trainloader, testloader, optim_builder, criterion
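
Note (not part of the diff): the registered test components now hand back builder callables rather than instances, so a test controls when the model is constructed (for example inside a ZeroInitContext) and builds the optimizer only once the model exists. A consumer-side sketch using only names from this diff:

from tests.components_to_test.registry import non_distributed_component_funcs

for get_components_func in non_distributed_component_funcs:
    model_builder, trainloader, testloader, optim_builder, criterion = get_components_func()
    model = model_builder(checkpoint=False)    # decide here whether to enable checkpointing
    optimizer = optim_builder(model)           # optimizer is created against the fresh model
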
@@ -36,9 +36,15 @@ class DummyDataLoader(DummyDataGenerator):

 @non_distributed_component_funcs.register(name='repeated_computed_layers')
 def get_training_components():
-    model = NetWithRepeatedlyComputedLayers(checkpoint=True)
+
+    def model_builder(checkpoint=True):
+        return NetWithRepeatedlyComputedLayers(checkpoint)
+
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()
-    optim = torch.optim.Adam(model.parameters(), lr=0.001)
+
+    def optim_builder(model):
+        return torch.optim.Adam(model.parameters(), lr=0.001)
+
     criterion = torch.nn.CrossEntropyLoss()
-    return model, trainloader, testloader, optim, criterion
+    return model_builder, trainloader, testloader, optim_builder, criterion
@@ -22,9 +22,15 @@ def get_cifar10_dataloader(train):

 @non_distributed_component_funcs.register(name='resnet18')
 def get_resnet_training_components():
-    model = resnet18(num_classes=10)
+
+    def model_builder(checkpoint=False):
+        return resnet18(num_classes=10)
+
     trainloader = get_cifar10_dataloader(train=True)
     testloader = get_cifar10_dataloader(train=False)
-    optim = torch.optim.Adam(model.parameters(), lr=0.001)
+
+    def optim_builder(model):
+        return torch.optim.Adam(model.parameters(), lr=0.001)
+
     criterion = torch.nn.CrossEntropyLoss()
-    return model, trainloader, testloader, optim, criterion
+    return model_builder, trainloader, testloader, optim_builder, criterion
@@ -16,10 +16,11 @@ CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None

 def run_train():
     for get_components_func in non_distributed_component_funcs:
-        model, train_dataloader, _, optimizer, criterion = get_components_func()
+        model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
+        model = model_builder(checkpoint=False)
         engine, train_dataloader, *args = colossalai.initialize(model=model,
-                                                                optimizer=optimizer,
+                                                                optimizer=optimizer_builder(model),
                                                                 criterion=criterion,
                                                                 train_dataloader=train_dataloader)


@@ -9,22 +9,27 @@ import torch
 import torch.multiprocessing as mp
 from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
 from colossalai.zero.init_ctx import ZeroInitContext
-from common import CONFIG, Net
+from common import CONFIG
 from colossalai.utils import free_port
+from tests.components_to_test.registry import non_distributed_component_funcs


 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

-    with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=TensorShardStrategy(), shard_param=True):
-        # Note Net(checkpoint=True).cuda() moving to cuda is useless
-        model = Net(checkpoint=True)
+    for get_components_func in non_distributed_component_funcs:
+        model_builder, _, _, _, _ = get_components_func()
+        with ZeroInitContext(convert_fp16=True,
+                             convert_cuda=True,
+                             shard_strategy=TensorShardStrategy(),
+                             shard_param=True):
+            model = model_builder(checkpoint=True)

         for param in model.parameters():
             assert hasattr(param, 'ca_attr')
             assert param.ca_attr.data.dtype == torch.half
             assert param.ca_attr._data_sharded_tensor.is_sharded
             assert param.ca_attr.data.device.type == 'cuda'


 @pytest.mark.dist
@@ -46,6 +46,8 @@ def _run_shard_param_v2(rank, world_size, port):
     sparam = ShardedParamV2(param=param, process_group=None)

     allclose(sparam.data, param_ref.data)
+
+    sparam.remove_torch_payload()
     assert (param.data.numel() == 1)
