mirror of https://github.com/hpcaitech/ColossalAI
[gemini] async grad chunk reduce (all-reduce&reduce-scatter) (#5713)
* [gemini] async grad chunk reduce (all-reduce&reduce-scatter)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* [gemini] add test
* [gemini] rename func
* [gemini] update llama benchmark
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* [gemini] use tensor counter
* [gemini] change default config in GeminiPlugin and GeminiDDP
* [chore] typo
* [gemini] fix sync issue & add test cases
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

parent 85946d4236
commit 2fc85abf43

@@ -361,6 +361,7 @@ class GeminiPlugin(DPPluginBase):
         enable_sequence_parallelism: bool = False,
         enable_jit_fused: bool = False,
         enable_sequence_overlap: bool = False,
+        enable_async_reduce: bool = True,
         verbose: bool = False,
     ) -> None:
         super().__init__()

@@ -386,6 +387,7 @@ class GeminiPlugin(DPPluginBase):
             memstats=memstats,
             mixed_precision=PRECISION_STR_TO_DTYPE[precision],
             master_weights=master_weights,
+            enable_async_reduce=enable_async_reduce,
         )
         self.zero_optim_config = dict(
             gpu_margin_mem_ratio=gpu_margin_mem_ratio,

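For reference, a minimal end-to-end sketch of how the new plugin flag is meant to be used (the model, optimizer and launch call below are illustrative placeholders, not part of this commit):

    import torch.nn as nn
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin
    from colossalai.nn.optimizer import HybridAdam

    colossalai.launch_from_torch()                    # assumes the script is started via torchrun
    model = nn.Linear(1024, 1024)                     # placeholder model
    plugin = GeminiPlugin(enable_async_reduce=True)   # new flag; True is the default after this change
    booster = Booster(plugin=plugin)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    model, optimizer, *_ = booster.boost(model, optimizer)
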
@@ -164,6 +164,8 @@ class Chunk:
         self.l2_norm = None

         self.grad_chunk = None
+        # the async all-reduce/reduce-scatter work of this grad chunk (None means sync)
+        self.grad_reduce_work = None

     @property
     def memory_usage(self) -> Dict[str, int]:

@@ -244,7 +246,7 @@ class Chunk:
             assert self.cuda_shard is not None  # only check on CUDA
             valid_tensor = self.cuda_shard[: self.valid_end]

-        return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
+        return torch.isinf(valid_tensor).any() | torch.isnan(valid_tensor).any()

     def set_l2_norm(self) -> None:
         """Record l2 norm of this chunks on CUDA."""

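The dropped .item() calls are tied to the async path: .item() copies the result back to the host and blocks until the device has finished, whereas keeping the result as a device-side tensor lets the check be enqueued without a synchronization point. A standalone illustration of the difference (not part of the commit):

    import torch

    x = torch.randn(1024, device="cuda")
    # enqueued on the device, no host synchronization yet
    flag = torch.isinf(x).any() | torch.isnan(x).any()
    # only converting to a Python bool forces the host to wait for the device
    if flag.item():
        print("overflow detected")
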
@@ -374,37 +376,49 @@ class Chunk:
         if self.is_gathered:
             self.__scatter()

-    def reduce(self):
+    def reduce(self, async_op: bool = False):
         """Reduce scatter all the gradients. It's an operation done in CUDA."""
         # sanity check
         assert self.is_gathered
+        assert self.grad_reduce_work is None
         if self.pg_size == 1:
             # tricky code here
             # just move cuda_global_chunk to cuda_shard
             # the communication is not necessary
             self.__scatter()
             if self.extra_dp_group is not None:
-                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
+                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
         elif self.keep_gathered:
             # we use all-reduce here
-            dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
-            if self.extra_dp_group is not None:
-                dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
+            self.grad_reduce_work = dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg, async_op=async_op)
+            if self.extra_dp_group is not None:  # cannot guranatee the order of multiple all-reduce
+                self.wait_async_reduce()
+                self.grad_reduce_work = dist.all_reduce(
+                    self.cuda_global_chunk, group=self.extra_dp_group, async_op=async_op
+                )
         else:
             self.cuda_shard = torch.empty(
                 self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
             )

             input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
+            self.grad_reduce_work = dist.reduce_scatter(
+                self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
+            )

             if self.extra_dp_group is not None:
-                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
+                self.wait_async_reduce()
+                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)

         free_storage(self.cuda_global_chunk)
         self.is_gathered = False
         self.__update_tensors_state(TensorState.HOLD)

+    def wait_async_reduce(self) -> None:
+        if self.grad_reduce_work is not None:
+            self.grad_reduce_work.wait()
+            self.grad_reduce_work = None
+
     def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
         """
         Make a transition of the tensor into the next state.

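The hunk above is the core of the change: with async_op=True the collective returns a Work handle instead of blocking, so the reduction of one grad chunk can overlap with the backward computation that produces the next one, and wait_async_reduce() is called only when the result is actually needed. A standalone sketch of that pattern with torch.distributed (process-group setup omitted; names are illustrative):

    import torch
    import torch.distributed as dist

    def reduce_grad_shard(global_chunk: torch.Tensor, world_size: int):
        # assumes global_chunk.numel() is divisible by world_size
        shard = torch.empty(global_chunk.numel() // world_size,
                            dtype=global_chunk.dtype, device=global_chunk.device)
        inputs = list(torch.chunk(global_chunk, chunks=world_size, dim=0))
        # async_op=True returns a Work handle instead of blocking the caller
        work = dist.reduce_scatter(shard, inputs, async_op=True)
        return shard, work

    # ... backward computation for other parameters can proceed here ...
    # before the shard is consumed (e.g. by the optimizer step):
    #     work.wait()
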
@@ -41,7 +41,7 @@ class ChunkManager:
         self.reuse_fp16_chunk = reuse_fp16_chunk
         # Whether model is accumulating gradients,
         self.accumulating_grads = False
-        self.overflow_counter = 0
+        self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())

     def register_tensor(
         self,

@@ -143,12 +143,12 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)

-    def reduce_chunk(self, chunk: Chunk) -> bool:
+    def reduce_chunk(self, chunk: Chunk, async_op: bool = False) -> bool:
         """Reduce or all reduce the chunk."""
         if not chunk.can_reduce:
             return False
         self.__sub_memory_usage(chunk.memory_usage)
-        chunk.reduce()
+        chunk.reduce(async_op=async_op)
         self.__sub_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)
         return True

@@ -272,7 +272,7 @@ class ChunkManager:
         return grad_chunk

     def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk:
-        """Rearrange gradients accumulated in chunk.grad_chunk, and getP prepared for gradient reduction."""
+        """Rearrange gradients accumulated in chunk.grad_chunk, and get prepared for gradient reduction."""

         assert chunk.grad_chunk is not None

@@ -96,6 +96,7 @@ class GeminiDDP(ModelWrapper):
         master_weights: bool = True,
         extra_dp_group: Optional[ProcessGroup] = None,
         verbose: bool = False,
+        enable_async_reduce: bool = True,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
         reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False

@@ -178,6 +179,7 @@ class GeminiDDP(ModelWrapper):
             if is_ddp_ignored(p):
                 continue
             if p.requires_grad:
+                assert not hasattr(p, "_grad_handle")
                 p._grad_handle = p.register_hook(
                     partial(
                         GeminiDDP.grad_handle,

@@ -187,6 +189,7 @@ class GeminiDDP(ModelWrapper):
                         master_weights=self.master_weights,
                         enable_gradient_accumulation=self.enable_gradient_accumulation,
                         p=p,
+                        async_reduce=enable_async_reduce,
                     )
                 )

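The per-parameter hook is registered with functools.partial so that the static grad_handle receives the gradient plus a set of bound keyword arguments, now including the new async_reduce flag. A generic sketch of that mechanism (names and the print body are illustrative only):

    from functools import partial
    import torch

    def grad_handle(grad: torch.Tensor, *, name: str, async_reduce: bool):
        # In GeminiDDP the hook copies grad into the chunk and calls reduce_chunk();
        # here we only show how extra arguments get bound to a gradient hook.
        print(f"grad for {name}: norm={grad.norm():.3f}, async_reduce={async_reduce}")
        return grad

    p = torch.nn.Parameter(torch.randn(4))
    p.register_hook(partial(grad_handle, name="p", async_reduce=True))
    p.sum().backward()
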
@@ -334,6 +337,11 @@ class GeminiDDP(ModelWrapper):
             setattr(param, "_gemini_reduced", False)

     def _post_backward(self):
+        for param in self.param2name:
+            if hasattr(param, "_release_grad_chunk_cb"):
+                param._release_grad_chunk_cb()
+                delattr(param, "_release_grad_chunk_cb")
+
         if self.chunk_manager.accessed_mem != 0:
             error_params = ["Reduction failed at followed parameters:"]
             for param in self.param2name:

@@ -371,6 +379,7 @@ class GeminiDDP(ModelWrapper):
         master_weights: bool,
         enable_gradient_accumulation: bool,
         p: nn.Parameter,
+        async_reduce: bool,
     ):
         setattr(p, "_gemini_reduced", True)
         empty_grad = torch.empty_like(grad)

@@ -406,8 +415,35 @@ class GeminiDDP(ModelWrapper):
                 grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
             else:
                 grad_chunk.add_tensor_to_chunk_slice(p, grad)
-        reduced = chunk_manager.reduce_chunk(grad_chunk)
-        if reduced:
+        reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce)
+        if reduced:  # if not async, can release immediately, else release in when work finished
+            if async_reduce:
+                # dirty fix by installing callback
+                assert not hasattr(p, "_release_grad_chunk_cb")
+
+                def _release_grad_chunk_cb():
+                    grad_chunk.wait_async_reduce()
+                    GeminiDDP.release_grad_chunk_handle(
+                        chunk_manager,
+                        grads_device,
+                        master_weights,
+                        enable_gradient_accumulation,
+                        p,
+                        chunk,
+                        grad_chunk,
+                    )
+
+                p._release_grad_chunk_cb = _release_grad_chunk_cb
+            else:
+                GeminiDDP.release_grad_chunk_handle(
+                    chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
+                )
+        return empty_grad
+
+    @staticmethod
+    def release_grad_chunk_handle(
+        chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
+    ):
         if not chunk_manager.reuse_fp16_chunk:
             if chunk.keep_gathered:
                 chunk_manager.fake_release_chunk(chunk)

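Because the collective may still be in flight when the hook returns, the cleanup that used to happen inline (releasing or moving the grad chunk) is packaged into a closure stored on the parameter and flushed later in _post_backward. A simplified, illustrative sketch of that idea (not the actual GeminiDDP code; assumes an initialized process group):

    import torch.distributed as dist

    def launch_async_reduce(param, grad_buffer, release_fn):
        work = dist.all_reduce(grad_buffer, async_op=True)

        def _release_cb():
            work.wait()      # make sure the collective has finished
            release_fn()     # only then is it safe to free / move the buffer

        param._release_grad_chunk_cb = _release_cb   # stash the callback on the parameter

    def post_backward(params):
        # flush every pending callback once the whole backward pass is done
        for p in params:
            if hasattr(p, "_release_grad_chunk_cb"):
                p._release_grad_chunk_cb()
                delattr(p, "_release_grad_chunk_cb")
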
@@ -429,7 +465,6 @@ class GeminiDDP(ModelWrapper):
             chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
             if not (master_weights) or (enable_gradient_accumulation):
                 chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
-        return empty_grad

     def zero_grad(self, set_to_none: bool = False) -> None:
         self.module.zero_grad(set_to_none=True)

@@ -62,10 +62,10 @@ class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
         self.module = module

     def check_local_overflow(self) -> bool:
-        return self.module.chunk_manager.overflow_counter > 0
+        return self.module.chunk_manager.overflow_counter.item() > 0

     def pre_zero_grad(self) -> None:
-        self.module.chunk_manager.overflow_counter = 0
+        self.module.chunk_manager.overflow_counter.zero_()


 class GeminiOptimizer(OptimizerWrapper):

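Turning overflow_counter into a device tensor means each chunk's inf/nan result can be accumulated on the GPU without a host round-trip; the single .item() in check_local_overflow() then becomes the only synchronization point per step. A condensed sketch of the idea (assumes a CUDA device; names mirror the methods above):

    import torch

    overflow_counter = torch.tensor([0], dtype=torch.int, device="cuda")

    def record_chunk(chunk: torch.Tensor) -> None:
        # accumulate on the device, no host synchronization here
        overflow_counter.add_((torch.isinf(chunk).any() | torch.isnan(chunk).any()).int())

    def check_local_overflow() -> bool:
        # the one place the host actually waits for the device
        return overflow_counter.item() > 0

    def pre_zero_grad() -> None:
        overflow_counter.zero_()
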
@@ -76,6 +76,8 @@ def main():
     parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
     parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
     parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
+    parser.add_argument("--disable-async-reduce", action="store_true", help="Customize checkpoint", default=False)
+
     args = parser.parse_args()

     colossalai.launch_from_torch()

@@ -110,6 +112,7 @@ def main():
             extra_dp_size=args.extra_dp,
             enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.xformers,
+            enable_async_reduce=not args.disable_async_reduce,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(

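In the llama benchmark this wires the new behavior to a CLI switch: async reduce stays on by default, and passing --disable-async-reduce falls back to synchronous reduction, e.g. (the script path and plugin value below are illustrative):

    torchrun --nproc_per_node 8 benchmark.py --plugin gemini --disable-async-reduce
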
@@ -34,7 +34,8 @@ def check_equal(param, param_cp):
 @parameterize("init_device", [None, torch.device("cpu")])
 @parameterize("keep_gathered", [True, False])
 @parameterize("pin_memory", [True, False])
-def exam_chunk_basic(init_device, keep_gathered, pin_memory):
+@parameterize("async_op", [True, False])
+def exam_chunk_basic(init_device, keep_gathered, pin_memory, async_op):
     world_size = torch.distributed.get_world_size()
     pg = _get_default_group()
     my_chunk = Chunk(

@@ -94,9 +95,12 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):

     assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
     assert my_chunk.can_reduce
-    my_chunk.reduce()
+    my_chunk.reduce(async_op)
     assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4

+    if async_op:
+        my_chunk.wait_async_reduce()
+
     if keep_gathered is False:
         assert my_chunk.cuda_shard.size(0) == 1024 // world_size
         assert my_chunk.device_type == "cuda"

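The chunk test now exercises both modes: reduce() is launched with async_op and, on the async path, wait_async_reduce() is called before the shard is inspected. A distilled sketch of the equivalence one would expect between the two paths (assumes an initialized process group; sizes are illustrative):

    import torch
    import torch.distributed as dist

    def reduce_scatter_chunk(full: torch.Tensor, world_size: int, async_op: bool) -> torch.Tensor:
        shard = torch.empty(full.numel() // world_size, dtype=full.dtype, device=full.device)
        work = dist.reduce_scatter(shard, list(torch.chunk(full, world_size)), async_op=async_op)
        if async_op:
            work.wait()
        return shard

    # both calls should yield identical shards:
    #     sync_shard = reduce_scatter_chunk(full.clone(), world_size, async_op=False)
    #     async_shard = reduce_scatter_chunk(full.clone(), world_size, async_op=True)
    #     assert torch.equal(sync_shard, async_shard)
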
@@ -40,12 +40,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("use_grad_checkpoint", [False, True])
 @parameterize("master_weights", [False, True])
+@parameterize("enable_async_reduce", [False, True])
 def exam_gpt_fwd_bwd(
     placement_config,
     keep_gather,
     model_name: str,
     use_grad_checkpoint: bool = False,
     master_weights: bool = True,
+    enable_async_reduce=True,
 ):
     init_device = get_accelerator().get_current_device()
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(

@@ -69,7 +71,13 @@ def exam_gpt_fwd_bwd(
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = keep_gather
     model = GeminiDDP(
-        model, config_dict, init_device, pin_memory=True, **placement_config, master_weights=master_weights
+        model,
+        config_dict,
+        init_device,
+        pin_memory=True,
+        **placement_config,
+        master_weights=master_weights,
+        enable_async_reduce=enable_async_reduce,
     )
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)

@@ -50,8 +50,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [False, True])
 @parameterize("use_grad_checkpoint", [False, True])
+@parameterize("enable_async_reduce", [False, True])
 def exam_gemini_grad_acc(
-    placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
+    placement_config,
+    keep_gathered: bool,
+    model_name: str,
+    master_weights: bool,
+    use_grad_checkpoint: bool,
+    enable_async_reduce: bool,
 ):
     init_device = get_accelerator().get_current_device()
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(

@@ -81,10 +87,13 @@ def exam_gemini_grad_acc(
         pin_memory=True,
         enable_gradient_accumulation=True,
         master_weights=master_weights,
+        enable_async_reduce=enable_async_reduce,
         **placement_config,
     )
     optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
-    gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1, max_norm=1.0)
+    gemini_optim = GeminiOptimizer(
+        optimizer, gemini_model, initial_scale=1, max_norm=1.0, enable_async_reduce=enable_async_reduce
+    )

     rank = dist.get_rank()

@@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [True, False])
-def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
+@parameterize("enable_async_reduce", [False, True])
+def exam_grad_clipping(placement_config, model_name: str, master_weights: bool, enable_async_reduce: bool):
     set_seed(1912)
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
         iter(model_zoo.get_sub_registry(model_name).values())

@@ -84,6 +85,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
         chunk_init_device=init_device,
         pin_memory=True,
         master_weights=master_weights,
+        enable_async_reduce=enable_async_reduce,
         **placement_config,
     )

@@ -73,7 +73,10 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
 @parameterize("model_name", TEST_MODELS)
 @parameterize("mixed_precision", [torch.half, torch.bfloat16])
 @parameterize("master_weights", [True, False])
-def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool):
+@parameterize("enable_async_reduce", [False, True])
+def exam_model_step(
+    placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool, enable_async_reduce=True
+):
     set_seed(42)
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
         iter(model_zoo.get_sub_registry(model_name).values())

@@ -96,7 +99,12 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool):
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = False
     model = GeminiDDP(
-        model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights
+        model,
+        config_dict,
+        **placement_config,
+        mixed_precision=mixed_precision,
+        master_weights=master_weights,
+        enable_async_reduce=enable_async_reduce,
     )

     optimizer = HybridAdam(model.parameters(), lr=1e-3)