Browse Source

[gemini] async grad chunk reduce (all-reduce&reduce-scatter) (#5713)

* [gemini] async grad chunk reduce (all-reduce&reduce-scatter)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [gemini] add test

* [gemini] rename func

* [gemini] update llama benchmark

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [gemini] use tensor counter

* [gemini] change default config in GeminiPlugin and GeminiDDP

* [chore] typo

* [gemini] fix sync issue & add test cases

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/5750/head
botbw 6 months ago committed by GitHub
parent
commit
2fc85abf43
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      colossalai/booster/plugin/gemini_plugin.py
  2. 32
      colossalai/zero/gemini/chunk/chunk.py
  3. 8
      colossalai/zero/gemini/chunk/manager.py
  4. 79
      colossalai/zero/gemini/gemini_ddp.py
  5. 4
      colossalai/zero/gemini/gemini_optimizer.py
  6. 3
      examples/language/llama/benchmark.py
  7. 8
      tests/test_zero/test_gemini/test_chunkv2.py
  8. 10
      tests/test_zero/test_gemini/test_fwd_bwd.py
  9. 13
      tests/test_zero/test_gemini/test_grad_accum.py
  10. 4
      tests/test_zero/test_gemini/test_grad_clip.py
  11. 12
      tests/test_zero/test_gemini/test_optim.py

2
colossalai/booster/plugin/gemini_plugin.py

@ -361,6 +361,7 @@ class GeminiPlugin(DPPluginBase):
enable_sequence_parallelism: bool = False, enable_sequence_parallelism: bool = False,
enable_jit_fused: bool = False, enable_jit_fused: bool = False,
enable_sequence_overlap: bool = False, enable_sequence_overlap: bool = False,
enable_async_reduce: bool = True,
verbose: bool = False, verbose: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
@ -386,6 +387,7 @@ class GeminiPlugin(DPPluginBase):
memstats=memstats, memstats=memstats,
mixed_precision=PRECISION_STR_TO_DTYPE[precision], mixed_precision=PRECISION_STR_TO_DTYPE[precision],
master_weights=master_weights, master_weights=master_weights,
enable_async_reduce=enable_async_reduce,
) )
self.zero_optim_config = dict( self.zero_optim_config = dict(
gpu_margin_mem_ratio=gpu_margin_mem_ratio, gpu_margin_mem_ratio=gpu_margin_mem_ratio,

32
colossalai/zero/gemini/chunk/chunk.py

@ -164,6 +164,8 @@ class Chunk:
self.l2_norm = None self.l2_norm = None
self.grad_chunk = None self.grad_chunk = None
# the async all-reduce/reduce-scatter work of this grad chunk (None means sync)
self.grad_reduce_work = None
@property @property
def memory_usage(self) -> Dict[str, int]: def memory_usage(self) -> Dict[str, int]:
@ -244,7 +246,7 @@ class Chunk:
assert self.cuda_shard is not None # only check on CUDA assert self.cuda_shard is not None # only check on CUDA
valid_tensor = self.cuda_shard[: self.valid_end] valid_tensor = self.cuda_shard[: self.valid_end]
return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item() return torch.isinf(valid_tensor).any() | torch.isnan(valid_tensor).any()
def set_l2_norm(self) -> None: def set_l2_norm(self) -> None:
"""Record l2 norm of this chunks on CUDA.""" """Record l2 norm of this chunks on CUDA."""
@ -374,37 +376,49 @@ class Chunk:
if self.is_gathered: if self.is_gathered:
self.__scatter() self.__scatter()
def reduce(self): def reduce(self, async_op: bool = False):
"""Reduce scatter all the gradients. It's an operation done in CUDA.""" """Reduce scatter all the gradients. It's an operation done in CUDA."""
# sanity check # sanity check
assert self.is_gathered assert self.is_gathered
assert self.grad_reduce_work is None
if self.pg_size == 1: if self.pg_size == 1:
# tricky code here # tricky code here
# just move cuda_global_chunk to cuda_shard # just move cuda_global_chunk to cuda_shard
# the communication is not necessary # the communication is not necessary
self.__scatter() self.__scatter()
if self.extra_dp_group is not None: if self.extra_dp_group is not None:
dist.all_reduce(self.cuda_shard, group=self.extra_dp_group) self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
elif self.keep_gathered: elif self.keep_gathered:
# we use all-reduce here # we use all-reduce here
dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg) self.grad_reduce_work = dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg, async_op=async_op)
if self.extra_dp_group is not None: if self.extra_dp_group is not None: # cannot guranatee the order of multiple all-reduce
dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group) self.wait_async_reduce()
self.grad_reduce_work = dist.all_reduce(
self.cuda_global_chunk, group=self.extra_dp_group, async_op=async_op
)
else: else:
self.cuda_shard = torch.empty( self.cuda_shard = torch.empty(
self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device() self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
) )
input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0)) input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg) self.grad_reduce_work = dist.reduce_scatter(
self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
)
if self.extra_dp_group is not None: if self.extra_dp_group is not None:
dist.all_reduce(self.cuda_shard, group=self.extra_dp_group) self.wait_async_reduce()
self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
free_storage(self.cuda_global_chunk) free_storage(self.cuda_global_chunk)
self.is_gathered = False self.is_gathered = False
self.__update_tensors_state(TensorState.HOLD) self.__update_tensors_state(TensorState.HOLD)
def wait_async_reduce(self) -> None:
if self.grad_reduce_work is not None:
self.grad_reduce_work.wait()
self.grad_reduce_work = None
def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None: def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
""" """
Make a transition of the tensor into the next state. Make a transition of the tensor into the next state.

8
colossalai/zero/gemini/chunk/manager.py

@ -41,7 +41,7 @@ class ChunkManager:
self.reuse_fp16_chunk = reuse_fp16_chunk self.reuse_fp16_chunk = reuse_fp16_chunk
# Whether model is accumulating gradients, # Whether model is accumulating gradients,
self.accumulating_grads = False self.accumulating_grads = False
self.overflow_counter = 0 self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
def register_tensor( def register_tensor(
self, self,
@ -143,12 +143,12 @@ class ChunkManager:
chunk = self.tensor_chunk_map[tensor] chunk = self.tensor_chunk_map[tensor]
chunk.tensor_trans_state(tensor, state) chunk.tensor_trans_state(tensor, state)
def reduce_chunk(self, chunk: Chunk) -> bool: def reduce_chunk(self, chunk: Chunk, async_op: bool = False) -> bool:
"""Reduce or all reduce the chunk.""" """Reduce or all reduce the chunk."""
if not chunk.can_reduce: if not chunk.can_reduce:
return False return False
self.__sub_memory_usage(chunk.memory_usage) self.__sub_memory_usage(chunk.memory_usage)
chunk.reduce() chunk.reduce(async_op=async_op)
self.__sub_accessed_chunk(chunk) self.__sub_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage) self.__add_memory_usage(chunk.memory_usage)
return True return True
@ -272,7 +272,7 @@ class ChunkManager:
return grad_chunk return grad_chunk
def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk: def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk:
"""Rearrange gradients accumulated in chunk.grad_chunk, and getP prepared for gradient reduction.""" """Rearrange gradients accumulated in chunk.grad_chunk, and get prepared for gradient reduction."""
assert chunk.grad_chunk is not None assert chunk.grad_chunk is not None

79
colossalai/zero/gemini/gemini_ddp.py

@ -96,6 +96,7 @@ class GeminiDDP(ModelWrapper):
master_weights: bool = True, master_weights: bool = True,
extra_dp_group: Optional[ProcessGroup] = None, extra_dp_group: Optional[ProcessGroup] = None,
verbose: bool = False, verbose: bool = False,
enable_async_reduce: bool = True,
) -> None: ) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16) assert mixed_precision in (torch.float16, torch.bfloat16)
reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@ -178,6 +179,7 @@ class GeminiDDP(ModelWrapper):
if is_ddp_ignored(p): if is_ddp_ignored(p):
continue continue
if p.requires_grad: if p.requires_grad:
assert not hasattr(p, "_grad_handle")
p._grad_handle = p.register_hook( p._grad_handle = p.register_hook(
partial( partial(
GeminiDDP.grad_handle, GeminiDDP.grad_handle,
@ -187,6 +189,7 @@ class GeminiDDP(ModelWrapper):
master_weights=self.master_weights, master_weights=self.master_weights,
enable_gradient_accumulation=self.enable_gradient_accumulation, enable_gradient_accumulation=self.enable_gradient_accumulation,
p=p, p=p,
async_reduce=enable_async_reduce,
) )
) )
@ -334,6 +337,11 @@ class GeminiDDP(ModelWrapper):
setattr(param, "_gemini_reduced", False) setattr(param, "_gemini_reduced", False)
def _post_backward(self): def _post_backward(self):
for param in self.param2name:
if hasattr(param, "_release_grad_chunk_cb"):
param._release_grad_chunk_cb()
delattr(param, "_release_grad_chunk_cb")
if self.chunk_manager.accessed_mem != 0: if self.chunk_manager.accessed_mem != 0:
error_params = ["Reduction failed at followed parameters:"] error_params = ["Reduction failed at followed parameters:"]
for param in self.param2name: for param in self.param2name:
@ -371,6 +379,7 @@ class GeminiDDP(ModelWrapper):
master_weights: bool, master_weights: bool,
enable_gradient_accumulation: bool, enable_gradient_accumulation: bool,
p: nn.Parameter, p: nn.Parameter,
async_reduce: bool,
): ):
setattr(p, "_gemini_reduced", True) setattr(p, "_gemini_reduced", True)
empty_grad = torch.empty_like(grad) empty_grad = torch.empty_like(grad)
@ -406,31 +415,57 @@ class GeminiDDP(ModelWrapper):
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk) grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
else: else:
grad_chunk.add_tensor_to_chunk_slice(p, grad) grad_chunk.add_tensor_to_chunk_slice(p, grad)
reduced = chunk_manager.reduce_chunk(grad_chunk) reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce)
if reduced: if reduced: # if not async, can release immediately, else release in when work finished
if not chunk_manager.reuse_fp16_chunk: if async_reduce:
if chunk.keep_gathered: # dirty fix by installing callback
chunk_manager.fake_release_chunk(chunk) assert not hasattr(p, "_release_grad_chunk_cb")
else:
chunk_manager.release_chunk(chunk) def _release_grad_chunk_cb():
if grad_chunk.is_gathered: grad_chunk.wait_async_reduce()
grad_chunk.cuda_global_chunk.div_(chunk.pg_size) GeminiDDP.release_grad_chunk_handle(
if chunk.extra_dp_group is not None: chunk_manager,
grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size) grads_device,
master_weights,
enable_gradient_accumulation,
p,
chunk,
grad_chunk,
)
p._release_grad_chunk_cb = _release_grad_chunk_cb
else: else:
grad_chunk.cuda_shard.div_(chunk.pg_size) GeminiDDP.release_grad_chunk_handle(
if chunk.extra_dp_group is not None: chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
grad_chunk.cuda_shard.div_(chunk.extra_dp_size) )
# check overflow elements
chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
if chunk.l2_norm_flag:
grad_chunk.set_l2_norm()
chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
if not (master_weights) or (enable_gradient_accumulation):
chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
return empty_grad return empty_grad
@staticmethod
def release_grad_chunk_handle(
chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
):
if not chunk_manager.reuse_fp16_chunk:
if chunk.keep_gathered:
chunk_manager.fake_release_chunk(chunk)
else:
chunk_manager.release_chunk(chunk)
if grad_chunk.is_gathered:
grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
if chunk.extra_dp_group is not None:
grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
else:
grad_chunk.cuda_shard.div_(chunk.pg_size)
if chunk.extra_dp_group is not None:
grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
# check overflow elements
chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
if chunk.l2_norm_flag:
grad_chunk.set_l2_norm()
chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
if not (master_weights) or (enable_gradient_accumulation):
chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
def zero_grad(self, set_to_none: bool = False) -> None: def zero_grad(self, set_to_none: bool = False) -> None:
self.module.zero_grad(set_to_none=True) self.module.zero_grad(set_to_none=True)

4
colossalai/zero/gemini/gemini_optimizer.py

@ -62,10 +62,10 @@ class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
self.module = module self.module = module
def check_local_overflow(self) -> bool: def check_local_overflow(self) -> bool:
return self.module.chunk_manager.overflow_counter > 0 return self.module.chunk_manager.overflow_counter.item() > 0
def pre_zero_grad(self) -> None: def pre_zero_grad(self) -> None:
self.module.chunk_manager.overflow_counter = 0 self.module.chunk_manager.overflow_counter.zero_()
class GeminiOptimizer(OptimizerWrapper): class GeminiOptimizer(OptimizerWrapper):

3
examples/language/llama/benchmark.py

@ -76,6 +76,8 @@ def main():
parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel") parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled") parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
parser.add_argument("--disable-async-reduce", action="store_true", help="Customize checkpoint", default=False)
args = parser.parse_args() args = parser.parse_args()
colossalai.launch_from_torch() colossalai.launch_from_torch()
@ -110,6 +112,7 @@ def main():
extra_dp_size=args.extra_dp, extra_dp_size=args.extra_dp,
enable_fused_normalization=torch.cuda.is_available(), enable_fused_normalization=torch.cuda.is_available(),
enable_flash_attention=args.xformers, enable_flash_attention=args.xformers,
enable_async_reduce=not args.disable_async_reduce,
) )
elif args.plugin == "gemini_auto": elif args.plugin == "gemini_auto":
plugin = GeminiPlugin( plugin = GeminiPlugin(

8
tests/test_zero/test_gemini/test_chunkv2.py

@ -34,7 +34,8 @@ def check_equal(param, param_cp):
@parameterize("init_device", [None, torch.device("cpu")]) @parameterize("init_device", [None, torch.device("cpu")])
@parameterize("keep_gathered", [True, False]) @parameterize("keep_gathered", [True, False])
@parameterize("pin_memory", [True, False]) @parameterize("pin_memory", [True, False])
def exam_chunk_basic(init_device, keep_gathered, pin_memory): @parameterize("async_op", [True, False])
def exam_chunk_basic(init_device, keep_gathered, pin_memory, async_op):
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
pg = _get_default_group() pg = _get_default_group()
my_chunk = Chunk( my_chunk = Chunk(
@ -94,9 +95,12 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4 assert my_chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
assert my_chunk.can_reduce assert my_chunk.can_reduce
my_chunk.reduce() my_chunk.reduce(async_op)
assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4 assert my_chunk.tensor_state_cnter[TensorState.HOLD] == 4
if async_op:
my_chunk.wait_async_reduce()
if keep_gathered is False: if keep_gathered is False:
assert my_chunk.cuda_shard.size(0) == 1024 // world_size assert my_chunk.cuda_shard.size(0) == 1024 // world_size
assert my_chunk.device_type == "cuda" assert my_chunk.device_type == "cuda"

10
tests/test_zero/test_gemini/test_fwd_bwd.py

@ -40,12 +40,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("use_grad_checkpoint", [False, True]) @parameterize("use_grad_checkpoint", [False, True])
@parameterize("master_weights", [False, True]) @parameterize("master_weights", [False, True])
@parameterize("enable_async_reduce", [False, True])
def exam_gpt_fwd_bwd( def exam_gpt_fwd_bwd(
placement_config, placement_config,
keep_gather, keep_gather,
model_name: str, model_name: str,
use_grad_checkpoint: bool = False, use_grad_checkpoint: bool = False,
master_weights: bool = True, master_weights: bool = True,
enable_async_reduce=True,
): ):
init_device = get_accelerator().get_current_device() init_device = get_accelerator().get_current_device()
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
@ -69,7 +71,13 @@ def exam_gpt_fwd_bwd(
config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["chunk_size"] = 5000
config_dict[world_size]["keep_gathered"] = keep_gather config_dict[world_size]["keep_gathered"] = keep_gather
model = GeminiDDP( model = GeminiDDP(
model, config_dict, init_device, pin_memory=True, **placement_config, master_weights=master_weights model,
config_dict,
init_device,
pin_memory=True,
**placement_config,
master_weights=master_weights,
enable_async_reduce=enable_async_reduce,
) )
optimizer = HybridAdam(model.parameters(), lr=1e-3) optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)

13
tests/test_zero/test_gemini/test_grad_accum.py

@ -50,8 +50,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("master_weights", [False, True]) @parameterize("master_weights", [False, True])
@parameterize("use_grad_checkpoint", [False, True]) @parameterize("use_grad_checkpoint", [False, True])
@parameterize("enable_async_reduce", [False, True])
def exam_gemini_grad_acc( def exam_gemini_grad_acc(
placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool placement_config,
keep_gathered: bool,
model_name: str,
master_weights: bool,
use_grad_checkpoint: bool,
enable_async_reduce: bool,
): ):
init_device = get_accelerator().get_current_device() init_device = get_accelerator().get_current_device()
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
@ -81,10 +87,13 @@ def exam_gemini_grad_acc(
pin_memory=True, pin_memory=True,
enable_gradient_accumulation=True, enable_gradient_accumulation=True,
master_weights=master_weights, master_weights=master_weights,
enable_async_reduce=enable_async_reduce,
**placement_config, **placement_config,
) )
optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3) optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1, max_norm=1.0) gemini_optim = GeminiOptimizer(
optimizer, gemini_model, initial_scale=1, max_norm=1.0, enable_async_reduce=enable_async_reduce
)
rank = dist.get_rank() rank = dist.get_rank()

4
tests/test_zero/test_gemini/test_grad_clip.py

@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("master_weights", [True, False]) @parameterize("master_weights", [True, False])
def exam_grad_clipping(placement_config, model_name: str, master_weights: bool): @parameterize("enable_async_reduce", [False, True])
def exam_grad_clipping(placement_config, model_name: str, master_weights: bool, enable_async_reduce: bool):
set_seed(1912) set_seed(1912)
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
iter(model_zoo.get_sub_registry(model_name).values()) iter(model_zoo.get_sub_registry(model_name).values())
@ -84,6 +85,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
chunk_init_device=init_device, chunk_init_device=init_device,
pin_memory=True, pin_memory=True,
master_weights=master_weights, master_weights=master_weights,
enable_async_reduce=enable_async_reduce,
**placement_config, **placement_config,
) )

12
tests/test_zero/test_gemini/test_optim.py

@ -73,7 +73,10 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty
@parameterize("model_name", TEST_MODELS) @parameterize("model_name", TEST_MODELS)
@parameterize("mixed_precision", [torch.half, torch.bfloat16]) @parameterize("mixed_precision", [torch.half, torch.bfloat16])
@parameterize("master_weights", [True, False]) @parameterize("master_weights", [True, False])
def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool): @parameterize("enable_async_reduce", [False, True])
def exam_model_step(
placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool, enable_async_reduce=True
):
set_seed(42) set_seed(42)
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
iter(model_zoo.get_sub_registry(model_name).values()) iter(model_zoo.get_sub_registry(model_name).values())
@ -96,7 +99,12 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt
config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["chunk_size"] = 5000
config_dict[world_size]["keep_gathered"] = False config_dict[world_size]["keep_gathered"] = False
model = GeminiDDP( model = GeminiDDP(
model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights model,
config_dict,
**placement_config,
mixed_precision=mixed_precision,
master_weights=master_weights,
enable_async_reduce=enable_async_reduce,
) )
optimizer = HybridAdam(model.parameters(), lr=1e-3) optimizer = HybridAdam(model.parameters(), lr=1e-3)

Loading…
Cancel
Save