mirror of https://github.com/hpcaitech/ColossalAI
[zero] polish ZeroInitContext (#540)
parent c11ff81b15
commit 1f90a3b129
@@ -4,12 +4,11 @@ from typing import Optional
 import torch
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
+from colossalai.logging import get_dist_logger
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
-from colossalai.logging import get_dist_logger, disable_existing_loggers


 def _substitute_init_recursively(cls, func):
@@ -107,20 +106,16 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
     """

     def __init__(self,
-                 convert_fp16: bool,
                  target_device: torch.device,
                  shard_strategy: BaseShardStrategy,
                  shard_param: bool = False,
-                 shard_grad: bool = False,
                  rm_torch_payload_on_the_fly: bool = False,
-                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int),
+                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long),
                  dp_process_group: Optional[ProcessGroup] = None):

         super().__init__()
-        self.convert_fp16 = convert_fp16
         self.target_device = target_device
         self.shard_param = shard_param
-        self.shard_grad = shard_grad
         self.shard_strategy = shard_strategy
         self.rm_torch_payload_on_the_fly = rm_torch_payload_on_the_fly
         self.initialized_param_list = []
@@ -157,8 +152,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):

             target_device = self.target_device

-            # convert to fp16 if necessary
-            if self.convert_fp16:
-                param.data = param.data.to(torch.half)
-                if param.grad is not None:
-                    param.grad = param.grad.to(torch.half)
+            # convert to fp16
+            param.data = param.data.to(torch.half)
+            if param.grad is not None:
+                param.grad = param.grad.to(torch.half)
@@ -179,5 +173,4 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         # We must cast them
         for buffer in module.buffers(recurse=False):
             buffer.data = buffer.data.to(device=torch.cuda.current_device())
-            if self.convert_fp16:
-                buffer.data = cast_tensor_to_fp16(buffer.data)
+            buffer.data = cast_tensor_to_fp16(buffer.data)
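With convert_fp16 and shard_grad removed from the constructor, any module built inside the context now always gets fp16, sharded parameters. A minimal usage sketch based on the updated call sites in the tests below; the Linear layer and the TensorShardStrategy choice are illustrative placeholders, and an already-initialized ColossalAI distributed environment is assumed:

import torch
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy

# Hypothetical call site after this change: no convert_fp16 / shard_grad arguments.
# model_numel_tensor accumulates the parameter count of everything built in the context.
model_numel_tensor = torch.zeros(1, dtype=torch.long)
with ZeroInitContext(target_device=torch.cuda.current_device(),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     model_numel_tensor=model_numel_tensor):
    model = torch.nn.Linear(16, 16)    # any module created here gets fp16, sharded params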
@@ -7,16 +7,17 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize
+from colossalai.logging import get_dist_logger
+from colossalai.testing import parameterize, rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
-from colossalai.utils.memory_tracer.model_data_memtracer import col_model_data_mem_usage
-from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.utils.memory_tracer.model_data_memtracer import \
+    col_model_data_mem_usage
 from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
+from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
-from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.logging import get_dist_logger

 from common import CONFIG
@@ -36,8 +37,7 @@ def run_model_test(init_device_type, shard_strategy_class):
             continue

         model_numel_tensor = torch.zeros(1, dtype=torch.int)
-        with ZeroInitContext(convert_fp16=True,
-                             target_device=init_device,
+        with ZeroInitContext(target_device=init_device,
                              shard_strategy=shard_strategy_class(),
                              shard_param=True,
                              model_numel_tensor=model_numel_tensor,
@@ -7,14 +7,13 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
-from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -32,8 +31,7 @@ def run_model_test(enable_autocast, shard_strategy_class):

         rm_torch_payload_on_the_fly = False

-        with ZeroInitContext(convert_fp16=True,
-                             target_device=torch.cuda.current_device(),
+        with ZeroInitContext(target_device=torch.cuda.current_device(),
                              shard_strategy=shard_strategy,
                              shard_param=True,
                              rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
@@ -8,7 +8,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import CPUAdam
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -16,7 +16,6 @@ from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
-from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -59,7 +58,6 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
     model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()

     with ZeroInitContext(
-            convert_fp16=True,
             target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
             shard_strategy=shard_strategy,
             shard_param=True,
@@ -10,11 +10,11 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.testing import rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import TensorShardStrategy
 from torchvision.models import resnet50
-from colossalai.testing import rerun_on_exception


 def run_dist(rank, world_size, port):
@@ -30,8 +30,7 @@ def run_dist(rank, world_size, port):
                       port=port,
                       backend='nccl')

-    with ZeroInitContext(convert_fp16=True,
-                         target_device=torch.cuda.current_device(),
+    with ZeroInitContext(target_device=torch.cuda.current_device(),
                          shard_strategy=gpc.config.zero.model_config.shard_strategy,
                          shard_param=True):
         model = resnet50()
@@ -8,13 +8,12 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
-from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs

 from common import CONFIG
@@ -28,8 +27,7 @@ def run_zero_state_dict(shard_strategy_class):
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()

-        with ZeroInitContext(convert_fp16=True,
-                             target_device=torch.cuda.current_device(),
+        with ZeroInitContext(target_device=torch.cuda.current_device(),
                              shard_strategy=shard_strategy,
                              shard_param=True,
                              rm_torch_payload_on_the_fly=False):
@@ -9,11 +9,11 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.core import global_context as gpc
+from colossalai.testing import rerun_on_exception
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
-from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -32,8 +32,7 @@ def run_dist(rank, world_size, port, parallel_config):
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
-        with ZeroInitContext(convert_fp16=hasattr(gpc.config, 'fp16'),
-                             target_device=torch.cuda.current_device(),
+        with ZeroInitContext(target_device=torch.cuda.current_device(),
                              shard_strategy=gpc.config.zero.model_config.shard_strategy,
                              shard_param=True):
             colo_model = model_builder(checkpoint=True)