[zero] polish low level optimizer (#2473)

pull/2476/head
HELSON 2023-01-13 14:56:17 +08:00 committed by GitHub
parent 8b7495dd54
commit a5dc4253c6
7 changed files with 95 additions and 124 deletions

View File

@@ -103,7 +103,11 @@ def split_half_float_double(tensor_list):
     return buckets


-def reduce_tensor_dp_group(tensor, dtype=None, dst_rank=None, pg: Optional[ProcessGroup] = None):
+def reduce_tensor_dp_group(tensor: torch.Tensor,
+                           dtype: Optional[torch.dtype] = None,
+                           dst_local_rank: Optional[int] = None,
+                           dst_global_rank: Optional[int] = None,
+                           group: Optional[dist.ProcessGroup] = None):
     """
     Reduce the tensor in the data parallel process group
@@ -128,36 +132,22 @@ def reduce_tensor_dp_group(tensor, dtype=None, dst_rank=None, pg: Optional[ProcessGroup] = None):
     else:
         tensor_to_reduce = tensor

-    if isinstance(pg, ProcessGroup):
-        group = pg.dp_process_group()
-        world_size = pg.dp_world_size()
-    else:
-        world_size = gpc.get_world_size(ParallelMode.DATA)
-        group = gpc.get_group(ParallelMode.DATA)
+    world_size = dist.get_world_size(group=group)
     tensor_to_reduce.div_(world_size)

     # if rank is None, all reduce will be used
     # else, reduce is used
-    use_all_reduce = dst_rank is None
+    use_all_reduce = dst_local_rank is None

     if use_all_reduce:
         dist.all_reduce(tensor_to_reduce, group=group)
     else:
-        if pg is not None:
-            ranks_in_group = pg.dp_rank_list()
-        else:
-            ranks_in_group = gpc.get_ranks_in_group(ParallelMode.DATA)
-        global_rank = ranks_in_group[dst_rank]
-        dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)
+        dist.reduce(tensor=tensor_to_reduce, dst=dst_global_rank, group=group)

     # recover the original dtype
     if tensor.dtype != dtype and tensor is not tensor_to_reduce:
-        if pg is not None:
-            local_rank = pg.dp_local_rank()
-        else:
-            local_rank = gpc.get_local_rank(ParallelMode.DATA)
-        if use_all_reduce or dst_rank == local_rank:
+        local_rank = dist.get_rank(group=group)
+        if use_all_reduce or dst_local_rank == local_rank:
             tensor.copy_(tensor_to_reduce)

     return tensor
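
For reference, a minimal usage sketch of the reworked helper (not part of the commit): the caller now passes a plain `torch.distributed` group and resolves the destination's global rank itself instead of handing over a Colossal-AI `ProcessGroup`. The group construction below is an assumption for illustration, and `reduce_tensor_dp_group` is assumed to be imported from the module shown above.

```python
import torch
import torch.distributed as dist

# minimal sketch, assuming dist.init_process_group() has already been called
# and reduce_tensor_dp_group (from the module above) is in scope
dp_global_ranks = list(range(dist.get_world_size()))    # hypothetical dp group = all ranks
dp_group = dist.new_group(ranks=dp_global_ranks)

grad = torch.randn(1024, device='cuda')

# all-reduce path: no destination rank is given
reduce_tensor_dp_group(tensor=grad, dtype=torch.float16, group=dp_group)

# reduce path: map the local rank inside the group to its global rank first
dst_local_rank = 0
reduce_tensor_dp_group(tensor=grad,
                       dtype=torch.float16,
                       dst_local_rank=dst_local_rank,
                       dst_global_rank=dp_global_ranks[dst_local_rank],
                       group=dp_group)
```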

View File

@@ -1,19 +1,12 @@
-from typing import Optional
-
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.tensor import ProcessGroup
+import torch.distributed as dist
+from torch.distributed import ProcessGroup


 class BaseStore:

-    def __init__(self, pg: Optional[ProcessGroup] = None):
-        if isinstance(pg, ProcessGroup):
-            self._world_size = pg.dp_world_size()
-            self._local_rank = pg.dp_local_rank()
-        else:
-            self._world_size = gpc.get_world_size(ParallelMode.DATA)
-            self._local_rank = gpc.get_local_rank(ParallelMode.DATA)
+    def __init__(self, torch_pg: ProcessGroup):
+        self._world_size = dist.get_world_size(group=torch_pg)
+        self._local_rank = dist.get_rank(group=torch_pg)

     @property
     def world_size(self):
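
A quick sketch of how the simplified store hierarchy is constructed after this change (not part of the commit; the group below is an assumption, any initialized `torch.distributed` group works):

```python
import torch.distributed as dist

# minimal sketch, assuming dist.init_process_group() has been called;
# BucketStore refers to the class in the diffs of this commit
torch_pg = dist.group.WORLD        # or a subgroup from dist.new_group(...)
store = BucketStore(torch_pg)      # BaseStore.__init__ now reads size/rank from torch.distributed
print(store.world_size)            # == dist.get_world_size(group=torch_pg)
```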

View File

@@ -1,14 +1,12 @@
-from typing import Optional
-
-from colossalai.tensor import ProcessGroup
+from torch.distributed import ProcessGroup

 from .base_store import BaseStore


 class BucketStore(BaseStore):

-    def __init__(self, pg: Optional[ProcessGroup] = None):
-        super().__init__(pg)
+    def __init__(self, torch_pg: ProcessGroup):
+        super().__init__(torch_pg)
         self._grads = dict()
         self._params = dict()
         self._num_elements_in_bucket = dict()

View File

@@ -1,16 +1,15 @@
-from typing import List, Optional
+from typing import List

 from torch import Tensor
-
-from colossalai.tensor import ProcessGroup
+from torch.distributed import ProcessGroup

 from .base_store import BaseStore


 class ParameterStore(BaseStore):

-    def __init__(self, pg: Optional[ProcessGroup] = None):
-        super().__init__(pg)
+    def __init__(self, torch_pg: ProcessGroup):
+        super().__init__(torch_pg)
         # param partitioning data structures
         self._fp16_param_to_rank = dict()
         self._rank_groupid_to_fp16_param_list = dict()

View File

@@ -10,7 +10,7 @@ from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.tensor import ProcessGroup
+from colossalai.tensor import ColoParameter, ProcessGroup
 from colossalai.utils.cuda import get_current_device

 from ._utils import (
@@ -34,32 +34,21 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     def __init__(
             self,
             optimizer: Optimizer,
-            pg: Optional[ProcessGroup] = None,
-            # grad scaler config
-            initial_scale=2**16,
-            min_scale=1,
-            growth_factor=2,
-            backoff_factor=0.5,
-            growth_interval=2000,
-            hysteresis=2,
+            initial_scale: int = 2**16,    # grad scaler config
+            min_scale: int = 1,
+            growth_factor: float = 2.,
+            backoff_factor: float = .5,
+            growth_interval: int = 2000,
+            hysteresis: int = 2,
             max_scale: int = 2**24,
-
-            # grad clipping
-            clip_grad_norm=0.0,
-            verbose=False,
-
-            # communication
-            reduce_bucket_size=1024 * 1024,
-            communication_dtype=None,
-            overlap_communication=False,
-
-            # stage 2
-            partition_grad=False,
-
-            # cpu offload
-            cpu_offload=False,
-            # forced dtype
-            forced_dtype=None):
+            clip_grad_norm: float = 0.0,    # grad clipping
+            verbose: bool = False,
+            reduce_bucket_size: int = 1024 * 1024,    # communication
+            communication_dtype: Optional[torch.dtype] = None,
+            overlap_communication: bool = False,
+            partition_grad: bool = False,    # stage 2
+            cpu_offload: bool = False,    # cpu offload
+            forced_dtype: Optional[torch.dtype] = None):

         # TODO: add support for
         # 1. fp16 master weights
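
With `pg` removed from the signature, a typical construction now looks like the sketch below (not code from this commit; `model` and the launched distributed context are assumptions, and the keyword values are only illustrative):

```python
import torch
from colossalai.zero import LowLevelZeroOptimizer

# minimal sketch: `model` is a hypothetical CUDA model inside a colossalai.launch-ed process
base_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
optim = LowLevelZeroOptimizer(base_optim,
                              initial_scale=2**16,
                              clip_grad_norm=1.0,
                              overlap_communication=True,
                              partition_grad=False)    # stage 1; set True for ZeRO stage 2
```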
@@ -76,16 +65,16 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         self._cpu_offload = cpu_offload

-        self._pg = pg
-        if isinstance(pg, ProcessGroup):
-            self._local_rank = pg.dp_local_rank()
-            self._world_size = pg.dp_world_size()
-            self._dp_group = pg.dp_process_group()
-            if pg.tp_world_size() > 1:
-                self._mp_group = pg.tp_process_group()
-            else:
-                self._mp_group = None
-        elif pg is None:
+        colo_pg = self._search_colo_process_group()
+        if isinstance(colo_pg, ProcessGroup):
+            self._local_rank = colo_pg.dp_local_rank()
+            self._world_size = colo_pg.dp_world_size()
+            self._dp_global_ranks = colo_pg.get_ranks_in_dp()
+            self._dp_torch_group = colo_pg.dp_process_group()
+            self._mp_torch_group = None
+            if colo_pg.tp_world_size() > 1:
+                self._mp_torch_group = colo_pg.tp_process_group()
+        elif colo_pg is None:
             dp_parallel_mode = ParallelMode.DATA
             mp_parallel_mode = ParallelMode.MODEL
@@ -93,14 +82,13 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             self._mp_parallel_mode = mp_parallel_mode
             self._local_rank = gpc.get_local_rank(dp_parallel_mode)
             self._world_size = gpc.get_world_size(dp_parallel_mode)
-
-            self._dp_group = gpc.get_group(dp_parallel_mode)
+            self._dp_global_ranks = gpc.get_ranks_in_group(dp_parallel_mode)
+            self._dp_torch_group = gpc.get_group(dp_parallel_mode)
+            self._mp_torch_group = None
             if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
-                self._mp_group = gpc.get_group(mp_parallel_mode)
-            else:
-                self._mp_group = None
+                self._mp_torch_group = gpc.get_group(mp_parallel_mode)
         else:
-            raise TypeError(f"pg should be None or a ProcesGroup")
+            raise NotImplementedError

         # fp16 and fp32 params for mixed precision training
         self._fp16_param_groups = dict()
         self._fp32_flat_param_groups_of_current_rank = dict()
@@ -136,14 +124,9 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         # ParameterStore will manage the tensor buffers used for zero
         # it will not manage the tensors used by mixed precision training
-        if self._pg is not None:
-            self._param_store = ParameterStore(self._pg)
-            self._grad_store = GradientStore(self._pg)
-            self._bucket_store = BucketStore(self._pg)
-        else:
-            self._param_store = ParameterStore(self._dp_parallel_mode)
-            self._grad_store = GradientStore(self._dp_parallel_mode)
-            self._bucket_store = BucketStore(self._dp_parallel_mode)
+        self._param_store = ParameterStore(self._dp_torch_group)
+        self._grad_store = GradientStore(self._dp_torch_group)
+        self._bucket_store = BucketStore(self._dp_torch_group)

         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
@@ -224,6 +207,30 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     def num_param_groups(self):
         return len(self._fp16_param_groups)

+    def _sanity_checks(self):
+        assert torch.cuda.is_available(), 'CUDA is required'
+        for param_group in self.optim.param_groups:
+            group_params = param_group['params']
+            for param in group_params:
+                assert param.dtype == self._dtype, \
+                    f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
+
+    def _search_colo_process_group(self):
+        colo_flag = False
+        colo_pg = None
+        for param_group in self.optim.param_groups:
+            group_params = param_group['params']
+            for param in group_params:
+                if isinstance(param, ColoParameter):
+                    colo_flag = True
+                    if colo_pg is None:
+                        colo_pg = param.get_process_group()
+                    else:
+                        assert colo_pg == param.get_process_group(), "All parameters should be in a same process group"
+                elif colo_flag:
+                    raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.")
+        return colo_pg
+
     def _partition_param_list(self, param_list):
         params_per_rank = [[] for _ in range(self._world_size)]
         numel_per_rank = [0 for _ in range(self._world_size)]
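
The intent of `_search_colo_process_group` is that either no parameter is a `ColoParameter` (it returns `None` and the optimizer falls back to the global context groups) or every parameter is one and they all share a single process group. Below is a standalone sketch of that invariant using a hypothetical stand-in class, not ColossalAI API:

```python
# standalone sketch of the invariant; FakeColoParam stands in for ColoParameter
class FakeColoParam:

    def __init__(self, pg):
        self._pg = pg

    def get_process_group(self):
        return self._pg


def search_group(params):
    found_colo, group = False, None
    for p in params:
        if isinstance(p, FakeColoParam):
            found_colo = True
            if group is None:
                group = p.get_process_group()
            else:
                assert group == p.get_process_group(), "all params must share one process group"
        elif found_colo:
            raise RuntimeError("mixing Colo and plain parameters is not allowed")
    return group


pg = object()                                      # stands in for a colossalai ProcessGroup
assert search_group([FakeColoParam(pg), FakeColoParam(pg)]) is pg
assert search_group(['plain', 'plain']) is None    # no ColoParameter -> caller falls back to gpc
```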
@@ -241,14 +248,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
         return params_per_rank

-    def _sanity_checks(self):
-        assert torch.cuda.is_available(), 'CUDA is required'
-        for param_group in self.optim.param_groups:
-            group_params = param_group['params']
-            for param in group_params:
-                assert param.dtype == self._dtype, \
-                    f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
-
     ###########################################################
     # Backward Reduction Hook
     ###########################################################
@@ -384,10 +383,14 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         with torch.cuda.stream(stream):
             flat = bucket.flatten()
+            reduce_global_rank = None
+            if reduce_rank is not None:
+                reduce_global_rank = self._dp_global_ranks[reduce_rank]
             reduced_flat = reduce_tensor_dp_group(tensor=flat,
                                                   dtype=self._communication_dtype,
-                                                  dst_rank=reduce_rank,
-                                                  pg=self._pg)
+                                                  dst_local_rank=reduce_rank,
+                                                  dst_global_rank=reduce_global_rank,
+                                                  group=self._dp_torch_group)

         # update the reduced tensor
         if reduce_rank is None or reduce_rank == self._local_rank:
@@ -456,8 +459,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id],
                                       params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
                                                                                              rank=self._local_rank),
-                                      dp_group=self._dp_group,
-                                      mp_group=self._mp_group)
+                                      dp_group=self._dp_torch_group,
+                                      mp_group=self._mp_torch_group)
             norm_groups.append(norm_group)

             # create flat gradient for the flat fp32 params
@@ -497,7 +500,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         for group_id in range(self.num_param_groups):
             for rank in range(self._world_size):
                 fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
-                handle = dist.broadcast(fp16_param, src=rank, group=self._dp_group, async_op=True)
+                handle = dist.broadcast(fp16_param, src=rank, group=self._dp_torch_group, async_op=True)
                 handles.append(handle)

         for handle in handles:
@@ -519,11 +522,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
                 break

         # all-reduce across dp group
-        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_group)
+        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group)

         # all-reduce over model parallel group
-        if self._mp_group:
-            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_group)
+        if self._mp_torch_group:
+            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group)

         if self._found_overflow.item() > 0:
             return True
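
The overflow check above follows the usual mixed-precision pattern: each rank raises a flag when it sees a non-finite gradient, then the flag is MAX all-reduced over the data-parallel group and, if present, the model-parallel group. A minimal, self-contained sketch of the pattern (the group handles are stand-ins, not the optimizer's own attributes):

```python
import torch
import torch.distributed as dist

# minimal sketch, assuming dist.init_process_group() has been called
dp_group = dist.group.WORLD    # stand-in for the data-parallel group
mp_group = None                # stand-in: no model-parallel group in this sketch

found_overflow = torch.zeros(1, device='cuda')
local_grads = [torch.randn(16, device='cuda')]

if any(not torch.isfinite(g).all() for g in local_grads):
    found_overflow.fill_(1.0)

dist.all_reduce(found_overflow, op=dist.ReduceOp.MAX, group=dp_group)       # across dp ranks
if mp_group is not None:
    dist.all_reduce(found_overflow, op=dist.ReduceOp.MAX, group=mp_group)   # then across mp ranks

overflow = found_overflow.item() > 0
```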

View File

@@ -35,18 +35,15 @@ def exam_zero_1_2_grad_acc():
     # create model
     zero1_model = TestModel().cuda()
     zero2_model = copy.deepcopy(zero1_model)

-    pg = ProcessGroup()
     # create optimizer
     zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
     zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
     zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
-                                            pg=pg,
                                             overlap_communication=True,
                                             initial_scale=32,
                                             clip_grad_norm=1.0,
                                             verbose=True)
     zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
-                                            pg=pg,
                                             overlap_communication=True,
                                             partition_grad=True,
                                             initial_scale=32,
@@ -86,7 +83,7 @@ def exam_zero_1_2_grad_acc():
         assert torch.equal(z1p.data, z2p.data)


-def exam_zero_1_grad_acc(use_pg=True):
+def exam_zero_1_grad_acc():
     local_rank = torch.distributed.get_rank()
     grad_scale = 32
     seed_all(2008)
@@ -105,9 +102,7 @@ def exam_zero_1_grad_acc(use_pg=True):
     # we only test stage 1 here
     # in `check_sharded_param_consistency.py`, we will test whether
     # level 1 and 2 will produce exactly the same results
-    pg = ProcessGroup() if use_pg else None    #ProcessGroup()
     zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
-                                           pg=pg,
                                            overlap_communication=False,
                                            initial_scale=grad_scale,
                                            reduce_bucket_size=262144,
@@ -158,9 +153,8 @@ def run_dist(rank, world_size, port):
 def run_dist(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

-    exam_zero_1_grad_acc(True)
-    exam_zero_1_grad_acc(False)
-    # exam_zero_1_2_grad_acc()
+    exam_zero_1_grad_acc()
+    exam_zero_1_2_grad_acc()


 @pytest.mark.dist
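
For completeness, `run_dist` entry points like this one are usually driven by a spawn-based test wrapper along the following lines; this is a sketch of the common pattern in the test suite, not code from this commit, and the decorator set and world size are assumptions:

```python
from functools import partial

import pytest
import torch.multiprocessing as mp

from colossalai.utils import free_port


@pytest.mark.dist
def test_grad_accumulation():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)    # each spawned process receives its rank as the first argument
```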

View File

@@ -9,7 +9,6 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close

 import colossalai
-from colossalai.tensor import ProcessGroup
 from colossalai.testing.random import seed_all
 from colossalai.utils import free_port
 from colossalai.zero import LowLevelZeroOptimizer
@@ -59,17 +58,14 @@ def exam_zero_1_2():
     zero1_model = TestModel().cuda()
     zero2_model = copy.deepcopy(zero1_model)

-    pg = ProcessGroup()
     # create optimizer
     zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
     zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
     zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
-                                            pg=pg,
                                             overlap_communication=True,
                                             initial_scale=128,
                                             verbose=True)
     zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
-                                            pg=pg,
                                             overlap_communication=True,
                                             partition_grad=True,
                                             initial_scale=128)
@@ -119,7 +115,7 @@ def exam_zero_1_torch_ddp():
     torch_model = copy.deepcopy(zero_model)

     zero_model = zero_model.cuda().half()
-    # torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
+    torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
     torch_model = torch_model.cuda()

     # for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
@@ -131,9 +127,7 @@ def exam_zero_1_torch_ddp():
     # we only test stage 1 here
     # in `check_sharded_param_consistency.py`, we will test whether
     # level 1 and 2 will produce exactly the same results
-    pg = ProcessGroup()
     zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
-                                           pg=pg,
                                            overlap_communication=True,
                                            initial_scale=1,
                                            reduce_bucket_size=262144)