Merge pull request #5749 from hpcaitech/prefetch

[Gemini] Prefetch next chunk before each op
pull/5760/head
botbw 2024-05-29 15:35:54 +08:00 committed by GitHub
commit 023ea13cb5
15 changed files with 239 additions and 65 deletions
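For orientation, here is a minimal sketch of how the new max_prefetch knob added by this PR would be turned on through GeminiPlugin. The parameter names come from the diff below; the model, optimizer, and launch call are illustrative placeholders and assume a torchrun-style environment:

    import torch
    import torch.nn as nn
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin

    colossalai.launch_from_torch()
    model = nn.Linear(1024, 1024)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # max_prefetch > 0 lets Gemini start gathering the chunks of upcoming ops
    # asynchronously instead of fetching each chunk right before its op runs.
    plugin = GeminiPlugin(placement_policy="static", max_prefetch=2, enable_async_reduce=True)
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model, optimizer)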

View File

@ -329,6 +329,7 @@ class GeminiPlugin(DPPluginBase):
chunk_init_device: Optional[torch.device] = None,
placement_policy: str = "static",
enable_gradient_accumulation: bool = False,
max_prefetch: int = 0,
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
@ -387,6 +388,7 @@ class GeminiPlugin(DPPluginBase):
memstats=memstats,
mixed_precision=PRECISION_STR_TO_DTYPE[precision],
master_weights=master_weights,
max_prefetch=max_prefetch,
enable_async_reduce=enable_async_reduce,
)
self.zero_optim_config = dict(

View File

@ -359,14 +359,15 @@ class Chunk:
else:
raise NotImplementedError
def access_chunk(self):
def access_chunk(self, async_access: bool = False) -> Optional[dist.Work]:
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
# sanity check
assert self.chunk_temp is None
maybe_work = None
if not self.is_gathered:
self.__gather()
maybe_work = self.__gather(async_op=async_access)
self.__update_tensors_ptr()
return maybe_work
def release_chunk(self):
"""Release the usable chunk. It's an operation done in CUDA."""
@ -512,17 +513,19 @@ class Chunk:
def get_tensors(self) -> List[torch.Tensor]:
return list(self.tensors_info.keys())
def __gather(self):
def __gather(self, async_op: bool = False) -> Optional[dist.Work]:
if not self.is_gathered:
# sanity check
assert self.cuda_shard is not None
alloc_storage(self.cuda_global_chunk)
gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)
self.cuda_shard = None
self.is_gathered = True
return work
return None
def __scatter(self):
if self.keep_gathered:

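The async path above hands back the dist.Work handle from a non-blocking all-gather; the gathered buffer is only safe to read after wait() is called on that handle, which GeminiManager.wait_chunks does later. A standalone sketch of that torch.distributed pattern, using a single-process gloo group and made-up tensor sizes purely to keep it runnable:

    from typing import Tuple

    import torch
    import torch.distributed as dist

    def gather_shard_async(shard: torch.Tensor) -> Tuple[torch.Tensor, dist.Work]:
        world_size = dist.get_world_size()
        # Allocate the full buffer and launch a non-blocking all-gather into views of it.
        full = torch.empty(shard.numel() * world_size, dtype=shard.dtype, device=shard.device)
        gather_list = list(torch.chunk(full, chunks=world_size, dim=0))
        work = dist.all_gather(gather_list, shard, async_op=True)  # returns a dist.Work handle
        return full, work

    if __name__ == "__main__":
        dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29512", rank=0, world_size=1)
        full, work = gather_shard_async(torch.arange(4, dtype=torch.float32))
        work.wait()  # `full` is only valid after this, mirroring GeminiManager.wait_chunks
        print(full)
        dist.destroy_process_group()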
View File

@ -111,15 +111,16 @@ class ChunkManager:
for group_name in self.chunk_groups:
self.__close_one_chunk(self.chunk_groups[group_name][-1])
def access_chunk(self, chunk: Chunk) -> None:
def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
"""Make the chunk can be used for calculation."""
if chunk in self.accessed_chunks:
return
return None
self.__sub_memory_usage(chunk.memory_usage)
if chunk.device_type == "cpu":
chunk.shard_move(get_accelerator().get_current_device())
self.__add_accessed_chunk(chunk)
maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access)
self.__add_memory_usage(chunk.memory_usage)
return maybe_work
def release_chunk(self, chunk: Chunk) -> None:
"""Scatter the chunk in CUDA."""
@ -251,10 +252,11 @@ class ChunkManager:
for k, v in usage.items():
self.total_mem[k] += v
def __add_accessed_chunk(self, chunk: Chunk):
chunk.access_chunk()
def __add_accessed_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
maybe_work = chunk.access_chunk(async_access=async_access)
self.accessed_chunks.add(chunk)
self.accessed_mem += chunk.chunk_mem
return maybe_work
def __sub_accessed_chunk(self, chunk: Chunk):
chunk.release_chunk()

View File

@ -78,6 +78,7 @@ class GeminiDDP(ModelWrapper):
chunk_init_device: torch.device = torch.device("cpu"),
placement_policy: str = "static",
enable_gradient_accumulation: bool = False,
max_prefetch: int = 0,
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
@ -131,6 +132,7 @@ class GeminiDDP(ModelWrapper):
offload_param_frac=offload_param_frac,
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
max_prefetch=max_prefetch,
)
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(self.gemini_manager)

View File

@ -24,16 +24,42 @@ class GeminiZeROHook(ColoParamOpHook):
self._training_phase = TrainingPhase.FORWARD
def pre_op(self, params):
# map params to chunks
params = [p for p in params if not is_ddp_ignored(p)]
chunks = self._chunk_manager.get_chunks(params)
all_chunks = self._chunk_manager.get_chunks(params)
# wait for prefetched chunks; filter out those that were not prefetched
chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks)
# transfer state
for p in params:
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
self._gemini_manager.sample_overall_data()
self._gemini_manager.adjust_layout(chunks)
for chunk in chunks:
# evict chunks, taking asynchronously fetched chunks into account
self._gemini_manager.adjust_layout(
all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0
)
# fetch the rest synchronously
for chunk in chunks_fetch_sync:
self._chunk_manager.access_chunk(chunk)
# record cuda model data of the current OP
# get possible chunks to prefetch
chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(
is_warmup=self._gemini_manager.is_warmup(),
compute_list=self._gemini_manager.compute_list,
compute_idx=self._gemini_manager.compute_idx,
async_works=self._gemini_manager.async_works,
)
# prefetch
for chunk in chunks_fetch_async:
maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
if maybe_work is not None:
self._gemini_manager.add_work(chunk, maybe_work)
# record cuda model data of the current OP, including memory for prefetched chunks
self._gemini_manager.record_model_data_volume()
def post_op(self, params):

View File

@ -1,12 +1,13 @@
import functools
from time import time
from typing import Dict, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple
import torch
import torch.distributed as dist
from .chunk import Chunk, ChunkManager
from .memory_tracer import ChunkMemStatsCollector, MemStats
from .placement_policy import PlacementPolicyFactory
from .placement_policy import PlacementPolicy, PlacementPolicyFactory
class GeminiManager:
@ -41,9 +42,12 @@ class GeminiManager:
self._mem_stats_collector = (
ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None
)
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs)
self._placement_policy = policy_cls(
chunk_manager=chunk_manager, mem_stats_collector=self._mem_stats_collector, **placement_kwargs
)
self._compute_list: List[Tuple[Chunk, ...]] = []
self._compute_idx: int = -1
self._async_works: Dict[Chunk, dist.Work] = {}
self._h2d_volume = 0
self._d2h_volume = 0
@ -91,18 +95,20 @@ class GeminiManager:
self._warmup = False
self.reset_attributes()
def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
def adjust_layout(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
"""Adjust the layout of stateful tensors according to the information provided
by mem_stats_collector, which should belong to a Sharded Model.
"""
# find stateful tensor in state COMPUTE
start = time()
self._record_chunks_order(chunks)
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
cuda_demand, can_evict_chunks = self._get_layout_info(self._compute_idx, self._warmup, chunks)
# don't evict chunks that are asynchronously fetched
can_evict_chunks = [chunk for chunk in can_evict_chunks if chunk not in self._async_works]
self._layout_time += time() - start
vol, evict_time = self._placement_policy.evict_tensors(
can_evict_chunks=hold_cuda_tensor_list,
can_evict_chunks=can_evict_chunks,
cuda_demand=cuda_demand,
warmup=self._warmup,
compute_list=self._compute_list,
@ -114,6 +120,21 @@ class GeminiManager:
# move COMPUTE tensors to CUDA
self._h2d_volume += cuda_demand
def wait_chunks(self, chunks: Iterable[Chunk]) -> Tuple[Chunk, ...]:
non_prefetched_chunks = []
for chunk in chunks:
if chunk in self._async_works:
self._async_works[chunk].wait()
del self._async_works[chunk]
else:
non_prefetched_chunks.append(chunk)
return tuple(non_prefetched_chunks)
def add_work(self, chunk: Chunk, work: dist.Work):
assert work is not None
assert chunk not in self._async_works
self._async_works[chunk] = work
@functools.lru_cache(maxsize=None)
def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]):
start = time()
@ -133,9 +154,9 @@ class GeminiManager:
can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
return cuda_demand, can_evict_chunks
def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
self._compute_idx += 1
if self._warmup and self._placement_policy.need_mem_stats:
if self._warmup and (self._placement_policy.need_mem_stats or record_anyway):
self._compute_list.append(chunks)
def sample_overall_data(self):
@ -156,6 +177,22 @@ class GeminiManager:
return self._mem_stats_collector.cuda_margin_mem
return None
@property
def placement_policy(self) -> PlacementPolicy:
return self._placement_policy
@property
def compute_list(self) -> List[Tuple[Chunk, ...]]:
return self._compute_list
@property
def compute_idx(self) -> int:
return self._compute_idx
@property
def async_works(self) -> Dict[Chunk, dist.Work]:
return self._async_works
@property
def is_cuda_margin_mem_avail(self) -> bool:
return self._placement_policy.need_mem_stats

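The contract between add_work and wait_chunks is small but easy to misread: add_work registers at most one in-flight Work per chunk, and wait_chunks blocks only on the requested chunks that are actually in flight, returning the rest so the hook can fetch them synchronously. A toy illustration of that bookkeeping with a stand-in Work object (the real handles come from dist.all_gather):

    class FakeWork:
        """Stand-in for torch.distributed.Work, for illustration only."""
        def __init__(self, name: str) -> None:
            self.name = name

        def wait(self) -> None:
            print(f"waited on prefetch of {self.name}")

    async_works = {"chunk_A": FakeWork("chunk_A")}  # chunk -> in-flight Work, as after add_work

    needed = ["chunk_A", "chunk_B"]
    non_prefetched = []
    for chunk in needed:
        if chunk in async_works:
            async_works[chunk].wait()     # finish the asynchronous gather
            del async_works[chunk]
        else:
            non_prefetched.append(chunk)  # must be fetched synchronously in pre_op
    print(non_prefetched)                 # ['chunk_B']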
View File

@ -5,6 +5,7 @@ from time import time
from typing import Dict, List, Optional, Tuple, Type
import torch
import torch.distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.legacy.utils.memory import colo_device_memory_capacity
@ -18,10 +19,15 @@ class PlacementPolicy(ABC):
need_mem_stats: bool = False
def __init__(
self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, **kwargs
self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
max_prefetch: int = 0,
**kwargs,
) -> None:
self.chunk_manager = chunk_manager
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
self.max_prefetch = max_prefetch
@abstractmethod
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
@ -33,18 +39,24 @@ class PlacementPolicy(ABC):
) -> None:
raise NotImplementedError
def get_prefetch_chunks(
self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
) -> List[Chunk]:
return [] # no prefetch by default
class StaticPlacementPolicy(PlacementPolicy):
def __init__(
self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
max_prefetch: int = 0,
shard_param_frac: float = 1.0,
offload_optim_frac: float = 0.0,
offload_param_frac: float = 0.0,
**kwargs,
) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
offload_param_frac = 0.0
@ -95,6 +107,24 @@ class StaticPlacementPolicy(PlacementPolicy):
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
def get_prefetch_chunks(
self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
) -> List[Chunk]:
if is_warmup: # no prefetch during warmup since we need compute_list
return []
can_prefetch = self.max_prefetch - len(async_works)
prefetch = []
for i in range(compute_idx + 1, len(compute_list)):
for chunk in compute_list[i]:
if len(prefetch) >= can_prefetch:
break
if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
prefetch.append(chunk)
else:
continue
break
return prefetch
class AutoPlacementPolicy(PlacementPolicy):
need_mem_stats: bool = True
@ -103,17 +133,20 @@ class AutoPlacementPolicy(PlacementPolicy):
self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
max_prefetch: int = 0,
warmup_non_model_data_ratio: float = 0.8,
steady_cuda_cap_ratio: float = 0.9,
**kwargs,
) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
self._warmup_non_model_data_ratio = warmup_non_model_data_ratio
self._steady_cuda_cap_ratio = steady_cuda_cap_ratio
self.__avail_cuda_model_data_for_prefetch = None
def evict_tensors(
self,
can_evict_chunks: List[Chunk],
@ -173,6 +206,7 @@ class AutoPlacementPolicy(PlacementPolicy):
f"Adjust layout failed! No enough CUDA memory! "
f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
)
self.__avail_cuda_model_data_for_prefetch = avail_cuda_model_data + freed_cuda_model_data
return freed_cuda_model_data, time() - start
@staticmethod
@ -198,6 +232,30 @@ class AutoPlacementPolicy(PlacementPolicy):
else:
grads_device_map[p] = torch.device("cpu")
def get_prefetch_chunks(
self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
) -> List[Chunk]:
if is_warmup: # no prefetch during warmup since we need compute_list
return []
avail_cuda_model_data = self.__avail_cuda_model_data_for_prefetch
self.__avail_cuda_model_data_for_prefetch = None # in case of double use
prefetch_chunk_memory = 0
can_prefetch = self.max_prefetch - len(async_works)
prefetch = []
for i in range(compute_idx + 1, len(compute_list)):
for chunk in compute_list[i]:
if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data:
break
if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
prefetch_chunk_memory += chunk.chunk_mem
prefetch.append(chunk)
else:
continue
break
return prefetch
class PlacementPolicyFactory:
policies: Dict[str, Type[PlacementPolicy]] = {

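Both get_prefetch_chunks implementations use Python's for/else to break out of the nested scan once the prefetch budget (and, for the auto policy, the available CUDA memory) is exhausted; a minimal illustration of that idiom with placeholder chunk names:

    compute_list = [["A", "B"], ["C"], ["D", "E"]]  # stand-in chunks grouped by upcoming op
    budget = 2                                      # e.g. max_prefetch - len(async_works)
    prefetch = []
    for step_chunks in compute_list:
        for chunk in step_chunks:
            if len(prefetch) >= budget:
                break                               # budget hit: leave the inner loop early
            if chunk not in prefetch:
                prefetch.append(chunk)
        else:
            continue                                # inner loop finished: scan the next op
        break                                       # inner loop was broken: stop scanning
    print(prefetch)                                 # ['A', 'B']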
View File

@ -0,0 +1 @@
../../../performance_evaluator.py

View File

@ -1,8 +1,6 @@
import time
from contextlib import nullcontext
import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
class DummyProfiler:
@ -24,19 +22,6 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
if enable_flag:
return profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
on_trace_ready=tensorboard_trace_handler(save_dir),
record_shapes=True,
profile_memory=True,
)
else:
return nullcontext(DummyProfiler())
def get_time_stamp():
cur_time = time.strftime("%d-%H:%M", time.localtime())
return cur_time

View File

@ -8,7 +8,8 @@ import psutil
import torch
import torch.nn as nn
from commons.model_zoo import model_builder
from commons.utils import get_data, get_profile_context, get_tflops, get_time_stamp
from commons.performance_evaluator import get_profile_context
from commons.utils import get_data, get_tflops, get_time_stamp
from packaging import version
import colossalai

View File

@ -1,11 +1,12 @@
import argparse
import resource
import time
from contextlib import nullcontext
import torch
from data_utils import RandomDataset
from model_utils import format_numel_str, get_model_numel
from performance_evaluator import PerformanceEvaluator
from performance_evaluator import PerformanceEvaluator, get_profile_context
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM
@ -76,8 +77,11 @@ def main():
parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
parser.add_argument("--disable-async-reduce", action="store_true", help="Customize checkpoint", default=False)
parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False)
parser.add_argument(
"--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation", default=False
)
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
args = parser.parse_args()
colossalai.launch_from_torch()
@ -112,6 +116,7 @@ def main():
extra_dp_size=args.extra_dp,
enable_fused_normalization=torch.cuda.is_available(),
enable_flash_attention=args.xformers,
max_prefetch=args.prefetch_num,
enable_async_reduce=not args.disable_async_reduce,
)
elif args.plugin == "gemini_auto":
@ -122,6 +127,8 @@ def main():
tp_size=args.tp,
extra_dp_size=args.extra_dp,
enable_fused_normalization=torch.cuda.is_available(),
max_prefetch=args.prefetch_num,
enable_async_reduce=not args.disable_async_reduce,
enable_flash_attention=args.xformers,
)
elif args.plugin == "fsdp":
@ -249,25 +256,37 @@ def main():
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
)
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
data_iter = iter(dataloader)
for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
performance_evaluator.on_step_start(step)
booster.execute_pipeline(
data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, return_loss=False
)
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
else:
for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
performance_evaluator.on_step_start(step)
outputs = model(**batch)
loss = outputs[0]
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(**batch)
with get_profile_context(
args.profile,
1,
len(dataloader) - 1,
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
) as prof:
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
data_iter = iter(dataloader)
for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
performance_evaluator.on_step_start(step)
booster.execute_pipeline(
data_iter,
model,
criterion=lambda outputs, inputs: outputs[0],
optimizer=optimizer,
return_loss=False,
)
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
prof.step()
else:
for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
performance_evaluator.on_step_start(step)
outputs = model(**batch)
loss = outputs[0]
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(**batch)
prof.step()
performance_evaluator.on_fit_end()
coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")

View File

@ -4,6 +4,7 @@ from typing import Optional
import torch
import torch.distributed as dist
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
from colossalai.accelerator import get_accelerator
from colossalai.cluster import DistCoordinator
@ -27,6 +28,33 @@ def all_reduce_mean(x: float, world_size: int) -> float:
return tensor.item()
def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
class DummyProfiler:
def __init__(self):
self.step_number = 0
def step(self):
self.step_number += 1
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
if enable_flag:
return profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
on_trace_ready=tensorboard_trace_handler(save_dir),
record_shapes=True,
profile_memory=True,
with_stack=True,
)
else:
return DummyProfiler()
class Timer:
def __init__(self) -> None:
self.start_time: Optional[float] = None

View File

@ -40,6 +40,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("use_grad_checkpoint", [False, True])
@parameterize("master_weights", [False, True])
@parameterize("max_prefetch", [0, 4])
@parameterize("enable_async_reduce", [False, True])
def exam_gpt_fwd_bwd(
placement_config,
@ -47,6 +48,7 @@ def exam_gpt_fwd_bwd(
model_name: str,
use_grad_checkpoint: bool = False,
master_weights: bool = True,
max_prefetch: int = 0,
enable_async_reduce=True,
):
init_device = get_accelerator().get_current_device()
@ -77,6 +79,7 @@ def exam_gpt_fwd_bwd(
pin_memory=True,
**placement_config,
master_weights=master_weights,
max_prefetch=max_prefetch,
enable_async_reduce=enable_async_reduce,
)
optimizer = HybridAdam(model.parameters(), lr=1e-3)

View File

@ -50,6 +50,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("master_weights", [False, True])
@parameterize("use_grad_checkpoint", [False, True])
@parameterize("max_prefetch", [0, 4])
@parameterize("enable_async_reduce", [False, True])
def exam_gemini_grad_acc(
placement_config,
@ -57,6 +58,7 @@ def exam_gemini_grad_acc(
model_name: str,
master_weights: bool,
use_grad_checkpoint: bool,
max_prefetch: int,
enable_async_reduce: bool,
):
init_device = get_accelerator().get_current_device()
@ -87,6 +89,7 @@ def exam_gemini_grad_acc(
pin_memory=True,
enable_gradient_accumulation=True,
master_weights=master_weights,
max_prefetch=max_prefetch,
enable_async_reduce=enable_async_reduce,
**placement_config,
)

View File

@ -52,8 +52,11 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("master_weights", [True, False])
@parameterize("max_prefetch", [0, 1, 4])
@parameterize("enable_async_reduce", [False, True])
def exam_grad_clipping(placement_config, model_name: str, master_weights: bool, enable_async_reduce: bool):
def exam_grad_clipping(
placement_config, model_name: str, master_weights: bool, max_prefetch: int, enable_async_reduce: bool
):
set_seed(1912)
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
iter(model_zoo.get_sub_registry(model_name).values())
@ -85,6 +88,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool,
chunk_init_device=init_device,
pin_memory=True,
master_weights=master_weights,
max_prefetch=max_prefetch,
enable_async_reduce=enable_async_reduce,
**placement_config,
)