[zero] low level optim supports ProcessGroup (#2464)

pull/2471/head
Jiarui Fang 2023-01-13 10:05:58 +08:00 committed by GitHub
parent e6943e2d11
commit 867c8c2d3a
8 changed files with 106 additions and 52 deletions

View File

@@ -1,4 +1,5 @@
 import math
+from typing import Optional

 import torch
 import torch.distributed as dist
@@ -7,6 +8,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.tensor import ProcessGroup
 from colossalai.utils import is_model_parallel_parameter
@@ -101,7 +103,7 @@ def split_half_float_double(tensor_list):
     return buckets


-def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.DATA):
+def reduce_tensor_dp_group(tensor, dtype=None, dst_rank=None, pg: Optional[ProcessGroup] = None):
     """
     Reduce the tensor in the data parallel process group
@@ -114,7 +116,7 @@ def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.
     :type tensor: torch.Tensor
     :type dtype: torch.dtype, optional
     :type dst_rank: int, optional
-    :type parallel_mode: ParallelMode, optional
+    :type pg: ProcessGroup, optional
     """
     # use the original dtype
     if dtype is None:
@@ -126,8 +128,13 @@ def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.
     else:
         tensor_to_reduce = tensor

-    world_size = gpc.get_world_size(parallel_mode)
-    group = gpc.get_group(parallel_mode)
+    if isinstance(pg, ProcessGroup):
+        group = pg.dp_process_group()
+        world_size = pg.dp_world_size()
+    else:
+        world_size = gpc.get_world_size(ParallelMode.DATA)
+        group = gpc.get_group(ParallelMode.DATA)
+
     tensor_to_reduce.div_(world_size)

     # if rank is None, all reduce will be used
@@ -137,13 +144,19 @@ def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.
     if use_all_reduce:
         dist.all_reduce(tensor_to_reduce, group=group)
     else:
-        ranks_in_group = gpc.get_ranks_in_group(parallel_mode)
+        if pg is not None:
+            ranks_in_group = pg.dp_rank_list()
+        else:
+            ranks_in_group = gpc.get_ranks_in_group(ParallelMode.DATA)
         global_rank = ranks_in_group[dst_rank]
         dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)

     # recover the original dtype
     if tensor.dtype != dtype and tensor is not tensor_to_reduce:
-        local_rank = gpc.get_local_rank(parallel_mode)
+        if pg is not None:
+            local_rank = pg.dp_local_rank()
+        else:
+            local_rank = gpc.get_local_rank(ParallelMode.DATA)
         if use_all_reduce or dst_rank == local_rank:
             tensor.copy_(tensor_to_reduce)
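
A minimal usage sketch of the reworked helper, not part of the commit: it assumes a job already launched with colossalai.launch, assumes the helper is importable from the low level zero `_utils` module (module path assumed here), and assumes that `ProcessGroup()` with no arguments builds a pure data-parallel group over all ranks, as the examples and tests later in this commit do.

# Hedged sketch -- the module path and the ProcessGroup() default layout are assumptions.
import torch
from colossalai.tensor import ProcessGroup
from colossalai.zero.sharded_optim._utils import reduce_tensor_dp_group  # assumed path

grad_buffer = torch.ones(1024, dtype=torch.float32, device='cuda')

# New-style call: the data-parallel group comes from the ProcessGroup object.
pg = ProcessGroup()
reduced = reduce_tensor_dp_group(grad_buffer, dtype=torch.float16, pg=pg)

# Legacy call: with pg left as None the helper falls back to gpc's
# ParallelMode.DATA group, so existing callers keep working unchanged.
reduced_legacy = reduce_tensor_dp_group(grad_buffer, dtype=torch.float16)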

View File

@@ -1,12 +1,19 @@
+from typing import Optional
+
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.tensor import ProcessGroup


 class BaseStore:

-    def __init__(self, dp_parallel_mode=ParallelMode.DATA):
-        self._world_size = gpc.get_world_size(dp_parallel_mode)
-        self._local_rank = gpc.get_local_rank(dp_parallel_mode)
+    def __init__(self, pg: Optional[ProcessGroup] = None):
+        if isinstance(pg, ProcessGroup):
+            self._world_size = pg.dp_world_size()
+            self._local_rank = pg.dp_local_rank()
+        else:
+            self._world_size = gpc.get_world_size(ParallelMode.DATA)
+            self._local_rank = gpc.get_local_rank(ParallelMode.DATA)

     @property
     def world_size(self):
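
The same fallback pattern appears in every bookkeeping store below. A self-contained sketch of that pattern, written here for illustration only (the helper name `resolve_dp_info` is hypothetical and not part of the codebase):

# Hypothetical helper sketching the pattern BaseStore now uses: prefer the
# ProcessGroup when one is given, otherwise fall back to gpc's DATA group.
from typing import Optional, Tuple

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.tensor import ProcessGroup


def resolve_dp_info(pg: Optional[ProcessGroup] = None) -> Tuple[int, int]:
    """Return (world_size, local_rank) for the data-parallel group."""
    if isinstance(pg, ProcessGroup):
        return pg.dp_world_size(), pg.dp_local_rank()
    # legacy path: read both values from the global context
    return gpc.get_world_size(ParallelMode.DATA), gpc.get_local_rank(ParallelMode.DATA)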

View File

@@ -1,13 +1,14 @@
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from typing import Optional
+
+from colossalai.tensor import ProcessGroup

 from .base_store import BaseStore


 class BucketStore(BaseStore):

-    def __init__(self, dp_parallel_mode):
-        super().__init__(dp_parallel_mode)
+    def __init__(self, pg: Optional[ProcessGroup] = None):
+        super().__init__(pg)
         self._grads = dict()
         self._params = dict()
         self._num_elements_in_bucket = dict()

View File

@@ -1,14 +1,16 @@
-from typing import List
+from typing import List, Optional

 from torch import Tensor

+from colossalai.tensor import ProcessGroup
+
 from .base_store import BaseStore


 class ParameterStore(BaseStore):

-    def __init__(self, dp_paralle_mode):
-        super().__init__(dp_paralle_mode)
+    def __init__(self, pg: Optional[ProcessGroup] = None):
+        super().__init__(pg)
         # param partitioning data structures
         self._fp16_param_to_rank = dict()
         self._rank_groupid_to_fp16_param_list = dict()

View File

@@ -1,5 +1,5 @@
 from functools import partial
-from itertools import groupby
+from typing import Optional

 import torch
 import torch.distributed as dist
@@ -10,6 +10,7 @@ from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.tensor import ProcessGroup
 from colossalai.utils.cuda import get_current_device

 from ._utils import (
@@ -18,7 +19,7 @@ from ._utils import (
     flatten,
     get_grad_accumulate_object,
     has_inf_or_nan,
-    reduce_tensor,
+    reduce_tensor_dp_group,
     release_param_grad,
     split_half_float_double,
     sync_param,
@@ -33,7 +34,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     def __init__(
             self,
             optimizer: Optimizer,
+            pg: Optional[ProcessGroup] = None,
             # grad scaler config
             initial_scale=2**16,
             min_scale=1,
@@ -54,9 +55,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             # stage 2
             partition_grad=False,

-            dp_parallel_mode=ParallelMode.DATA,
-            mp_parallel_mode=ParallelMode.MODEL,
-
             # cpu offload
             cpu_offload=False,
@@ -76,10 +74,21 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
        # stage 2
        self._partition_grads = partition_grad

-        # cpu_offload
        self._cpu_offload = cpu_offload

-        # get process groups
+        self._pg = pg
+        if isinstance(pg, ProcessGroup):
+            self._local_rank = pg.dp_local_rank()
+            self._world_size = pg.dp_world_size()
+
+            self._dp_group = pg.dp_process_group()
+            if pg.tp_world_size() > 1:
+                self._mp_group = pg.tp_process_group()
+            else:
+                self._mp_group = None
+        elif pg is None:
+            dp_parallel_mode = ParallelMode.DATA
+            mp_parallel_mode = ParallelMode.MODEL
            self._dp_parallel_mode = dp_parallel_mode
            self._mp_parallel_mode = mp_parallel_mode
            self._local_rank = gpc.get_local_rank(dp_parallel_mode)
@@ -90,7 +99,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
                self._mp_group = gpc.get_group(mp_parallel_mode)
            else:
                self._mp_group = None
+        else:
+            raise TypeError(f"pg should be None or a ProcesGroup")

        # fp16 and fp32 params for mixed precision training
        self._fp16_param_groups = dict()
        self._fp32_flat_param_groups_of_current_rank = dict()
@@ -126,6 +136,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
        # ParameterStore will manage the tensor buffers used for zero
        # it will not manage the tensors used by mixed precision training
+        if self._pg is not None:
+            self._param_store = ParameterStore(self._pg)
+            self._grad_store = GradientStore(self._pg)
+            self._bucket_store = BucketStore(self._pg)
+        else:
            self._param_store = ParameterStore(self._dp_parallel_mode)
            self._grad_store = GradientStore(self._dp_parallel_mode)
            self._bucket_store = BucketStore(self._dp_parallel_mode)
@@ -223,9 +238,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
                numel_per_rank[rank_to_go] += param.numel()

        if self._verbose:
-            self._logger.info(f'Number of elements on ranks: {numel_per_rank}',
-                              ranks=[0],
-                              parallel_mode=self._dp_parallel_mode)
+            self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
        return params_per_rank

    def _sanity_checks(self):
@@ -371,10 +384,10 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
            with torch.cuda.stream(stream):
                flat = bucket.flatten()
-                reduced_flat = reduce_tensor(tensor=flat,
-                                             dtype=self._communication_dtype,
-                                             dst_rank=reduce_rank,
-                                             parallel_mode=self._dp_parallel_mode)
+                reduced_flat = reduce_tensor_dp_group(tensor=flat,
+                                                      dtype=self._communication_dtype,
+                                                      dst_rank=reduce_rank,
+                                                      pg=self._pg)

                # update the reduced tensor
                if reduce_rank is None or reduce_rank == self._local_rank:
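
A minimal construction sketch, not taken from the commit: it assumes a job launched with colossalai.launch on several GPUs and assumes ProcessGroup accepts a tp_degree argument. With an explicit ProcessGroup the optimizer takes its DP group, and its MP group when tensor parallelism is enabled, from `pg`; with `pg=None` it keeps reading them from gpc as before.

# Hedged sketch -- the tp_degree argument and the multi-rank layout are assumptions.
import torch
from colossalai.tensor import ProcessGroup
from colossalai.zero import LowLevelZeroOptimizer

model = torch.nn.Linear(64, 64).cuda().half()
base_optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

pg = ProcessGroup(tp_degree=2)    # 2-way tensor parallelism, DP over the remaining ranks
optimizer = LowLevelZeroOptimizer(base_optimizer,
                                  pg=pg,                      # groups come from pg
                                  overlap_communication=True,
                                  partition_grad=False)       # ZeRO stage 1

# Omitting pg keeps the legacy behaviour: the DP and MP groups are looked up
# through gpc via ParallelMode.DATA and ParallelMode.MODEL.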

View File

@@ -290,14 +290,19 @@ def main():
         from torch.distributed.optim import ZeroRedundancyOptimizer
         optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
     elif args.distplan.startswith("zero"):
+        pg = ProcessGroup()
         model = model.half()
-        partition_flag = args.distplan == "zero2"
+        partition_flag = (args.distplan == "zero2")
         optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+
-        optimizer = LowLevelZeroOptimizer(optimizer,
-                                          reduce_bucket_size=12 * 1024 * 1024,
-                                          overlap_communication=True,
-                                          partition_grad=partition_flag,
-                                          verbose=True)
+        optimizer = LowLevelZeroOptimizer(
+            optimizer,
+            pg=pg,
+            reduce_bucket_size=12 * 1024 * 1024,
+            overlap_communication=True,
+            partition_grad=partition_flag,
+            verbose=True,
+        )

     # model is shared after TP
     numel = get_model_size(model)

View File

@@ -9,6 +9,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close

 import colossalai
+from colossalai.tensor import ProcessGroup
 from colossalai.testing.random import seed_all
 from colossalai.utils import free_port
 from colossalai.zero import LowLevelZeroOptimizer
@@ -34,16 +35,18 @@ def exam_zero_1_2_grad_acc():
     # create model
     zero1_model = TestModel().cuda()
     zero2_model = copy.deepcopy(zero1_model)

+    pg = ProcessGroup()
     # create optimizer
     zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
     zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
     zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
+                                            pg=pg,
                                             overlap_communication=True,
                                             initial_scale=32,
                                             clip_grad_norm=1.0,
                                             verbose=True)
     zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
+                                            pg=pg,
                                             overlap_communication=True,
                                             partition_grad=True,
                                             initial_scale=32,
@@ -83,7 +86,7 @@ def exam_zero_1_2_grad_acc():
         assert torch.equal(z1p.data, z2p.data)


-def exam_zero_1_grad_acc():
+def exam_zero_1_grad_acc(use_pg=True):
     local_rank = torch.distributed.get_rank()
     grad_scale = 32
     seed_all(2008)
@@ -92,6 +95,7 @@ def exam_zero_1_grad_acc():
     zero_model = TestModel()
     torch_model = copy.deepcopy(zero_model)

+    seed_all(2008)
     zero_model = zero_model.cuda()
     torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
@@ -101,7 +105,9 @@ def exam_zero_1_grad_acc():
     # we only test stage 1 here
     # in `check_sharded_param_consistency.py`, we will test whether
     # level 1 and 2 will produce exactly the same results
+    pg = ProcessGroup() if use_pg else None    #ProcessGroup()
     zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
+                                           pg=pg,
                                            overlap_communication=False,
                                            initial_scale=grad_scale,
                                            reduce_bucket_size=262144,
@@ -152,7 +158,8 @@ def exam_zero_1_grad_acc():

def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

-    exam_zero_1_grad_acc()
+    exam_zero_1_grad_acc(True)
+    exam_zero_1_grad_acc(False)
    # exam_zero_1_2_grad_acc()

View File

@@ -9,6 +9,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close

 import colossalai
+from colossalai.tensor import ProcessGroup
 from colossalai.testing.random import seed_all
 from colossalai.utils import free_port
 from colossalai.zero import LowLevelZeroOptimizer
@@ -58,14 +59,17 @@ def exam_zero_1_2():
     zero1_model = TestModel().cuda()
     zero2_model = copy.deepcopy(zero1_model)

+    pg = ProcessGroup()
     # create optimizer
     zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
     zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
     zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
+                                            pg=pg,
                                             overlap_communication=True,
                                             initial_scale=128,
                                             verbose=True)
     zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
+                                            pg=pg,
                                             overlap_communication=True,
                                             partition_grad=True,
                                             initial_scale=128)
@@ -127,7 +131,9 @@ def exam_zero_1_torch_ddp():
     # we only test stage 1 here
     # in `check_sharded_param_consistency.py`, we will test whether
     # level 1 and 2 will produce exactly the same results
+    pg = ProcessGroup()
     zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
+                                           pg=pg,
                                            overlap_communication=True,
                                            initial_scale=1,
                                            reduce_bucket_size=262144)