mirror of https://github.com/hpcaitech/ColossalAI
Merge branch 'prefetch' of github.com:botbw/ColossalAI into feature/prefetch
commit
fc2248cf99
|
@ -567,6 +567,7 @@ class Chunk:
|
||||||
return self is __o
|
return self is __o
|
||||||
|
|
||||||
def __repr__(self, detailed: bool = True):
|
def __repr__(self, detailed: bool = True):
|
||||||
|
return f"Chunk({self.count_id})"
|
||||||
output = [
|
output = [
|
||||||
"Chunk Information:\n",
|
"Chunk Information:\n",
|
||||||
"\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(
|
"\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(
|
||||||
|
|
|
@ -131,9 +131,10 @@ class GeminiDDP(ModelWrapper):
|
||||||
offload_param_frac=offload_param_frac,
|
offload_param_frac=offload_param_frac,
|
||||||
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
|
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
|
||||||
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
|
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
|
||||||
|
max_prefetch=max_prefetch,
|
||||||
)
|
)
|
||||||
self.force_outputs_fp32 = force_outputs_fp32
|
self.force_outputs_fp32 = force_outputs_fp32
|
||||||
self.param_op_hook = GeminiZeROHook(self.gemini_manager, max_prefetch=max_prefetch)
|
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
|
||||||
self.fp32_params: List[torch.Tensor] = list()
|
self.fp32_params: List[torch.Tensor] = list()
|
||||||
self.fp16_params: List[ColoParameter] = list()
|
self.fp16_params: List[ColoParameter] = list()
|
||||||
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
|
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Dict, List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
from colossalai.logging import DistributedLogger
|
from colossalai.logging import DistributedLogger
|
||||||
from colossalai.tensor.param_op_hook import ColoParamOpHook
|
from colossalai.tensor.param_op_hook import ColoParamOpHook
|
||||||
|
@ -12,8 +11,6 @@ from colossalai.utils import is_ddp_ignored
|
||||||
from colossalai.zero.gemini import TensorState
|
from colossalai.zero.gemini import TensorState
|
||||||
from colossalai.zero.gemini.gemini_mgr import GeminiManager
|
from colossalai.zero.gemini.gemini_mgr import GeminiManager
|
||||||
|
|
||||||
from .chunk import Chunk
|
|
||||||
|
|
||||||
|
|
||||||
class TrainingPhase(Enum):
|
class TrainingPhase(Enum):
|
||||||
FORWARD = 0
|
FORWARD = 0
|
||||||
|
@ -22,45 +19,60 @@ class TrainingPhase(Enum):
|
||||||
|
|
||||||
logger = DistributedLogger("gemini_hook")
|
logger = DistributedLogger("gemini_hook")
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
rank = int(os.environ["RANK"])
|
||||||
|
|
||||||
|
|
||||||
class GeminiZeROHook(ColoParamOpHook):
|
class GeminiZeROHook(ColoParamOpHook):
|
||||||
def __init__(self, gemini_manager: GeminiManager, max_prefetch: int = 0) -> None:
|
def __init__(self, gemini_manager: GeminiManager) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._gemini_manager = gemini_manager
|
self._gemini_manager = gemini_manager
|
||||||
self._chunk_manager = gemini_manager.chunk_manager
|
self._chunk_manager = gemini_manager.chunk_manager
|
||||||
self._training_phase = TrainingPhase.FORWARD
|
self._training_phase = TrainingPhase.FORWARD
|
||||||
self._max_prefetch = max_prefetch
|
|
||||||
self._async_works: Dict[Chunk, dist.work] = {}
|
|
||||||
|
|
||||||
def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]:
|
|
||||||
non_prefetched_chunks = []
|
|
||||||
for chunk in chunks:
|
|
||||||
if chunk in self._async_works:
|
|
||||||
print(f"prefetched {chunk.count_id}")
|
|
||||||
self._async_works[chunk].wait()
|
|
||||||
del self._async_works[chunk]
|
|
||||||
else:
|
|
||||||
non_prefetched_chunks.append(chunk)
|
|
||||||
return non_prefetched_chunks
|
|
||||||
|
|
||||||
def pre_op(self, params):
|
def pre_op(self, params):
|
||||||
|
# map params to chunks
|
||||||
params = [p for p in params if not is_ddp_ignored(p)]
|
params = [p for p in params if not is_ddp_ignored(p)]
|
||||||
all_chunks = self._chunk_manager.get_chunks(params)
|
all_chunks = self._chunk_manager.get_chunks(params)
|
||||||
|
|
||||||
# wait for prefetched chunks, filter those are not prefetched
|
# wait for prefetched chunks, filter those are not prefetched
|
||||||
chunks_fetch_sync = tuple(self.wait_chunks(all_chunks))
|
set(all_chunks)
|
||||||
|
chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks)
|
||||||
|
|
||||||
|
# transfer state
|
||||||
for p in params:
|
for p in params:
|
||||||
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
|
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
|
||||||
self._gemini_manager.sample_overall_data()
|
self._gemini_manager.sample_overall_data()
|
||||||
self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._max_prefetch > 0)
|
|
||||||
# fetch the rest chunks synchronously
|
# evit chunks, aware of async fetched
|
||||||
|
self._gemini_manager.adjust_layout(
|
||||||
|
all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# fetch the rest synchronously
|
||||||
for chunk in chunks_fetch_sync:
|
for chunk in chunks_fetch_sync:
|
||||||
self._chunk_manager.access_chunk(chunk)
|
self._chunk_manager.access_chunk(chunk)
|
||||||
chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(max_prefetch=self._max_prefetch)
|
|
||||||
|
# get possible chunks to prefetch
|
||||||
|
chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks()
|
||||||
|
if rank == 0 and not self._gemini_manager.is_warmup():
|
||||||
|
print(
|
||||||
|
f"compute_id: {self._gemini_manager.compute_idx} self._gemini_manager.compute_list: {self._gemini_manager.compute_list}"
|
||||||
|
)
|
||||||
|
print(f"{all_chunks=}")
|
||||||
|
print(f"accessed_chunks={self._chunk_manager.accessed_chunks}")
|
||||||
|
print(f"{chunks_fetch_sync=}")
|
||||||
|
print(f"{chunks_fetch_async=}")
|
||||||
|
print(f"works={list(self._gemini_manager._async_works.keys())}")
|
||||||
|
|
||||||
|
# prefetch
|
||||||
for chunk in chunks_fetch_async:
|
for chunk in chunks_fetch_async:
|
||||||
maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
|
maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
|
||||||
if maybe_work is not None:
|
if maybe_work is not None:
|
||||||
self._async_works[chunk] = maybe_work
|
self._gemini_manager.add_work(chunk, maybe_work)
|
||||||
|
if rank == 0 and not self._gemini_manager.is_warmup():
|
||||||
|
print(f"post accessed_chunks={self._chunk_manager.accessed_chunks}")
|
||||||
# record cuda model data of the current OP, including memory for prefetched chunks
|
# record cuda model data of the current OP, including memory for prefetched chunks
|
||||||
self._gemini_manager.record_model_data_volume()
|
self._gemini_manager.record_model_data_volume()
|
||||||
|
|
||||||
|
@ -88,11 +100,6 @@ class GeminiZeROHook(ColoParamOpHook):
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
|
def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
|
||||||
if training_phase == TrainingPhase.FORWARD:
|
|
||||||
self._cur_param_idx = 0
|
|
||||||
else:
|
|
||||||
self._cur_param_idx = len(self._param_visited_order) - 1
|
|
||||||
|
|
||||||
old_training_phase = self._training_phase
|
old_training_phase = self._training_phase
|
||||||
try:
|
try:
|
||||||
self._training_phase = training_phase
|
self._training_phase = training_phase
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import functools
|
import functools
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
from .chunk import Chunk, ChunkManager
|
from .chunk import Chunk, ChunkManager
|
||||||
from .memory_tracer import ChunkMemStatsCollector, MemStats
|
from .memory_tracer import ChunkMemStatsCollector, MemStats
|
||||||
|
@ -41,9 +42,10 @@ class GeminiManager:
|
||||||
self._mem_stats_collector = (
|
self._mem_stats_collector = (
|
||||||
ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None
|
ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None
|
||||||
)
|
)
|
||||||
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs)
|
self._placement_policy = policy_cls(self, chunk_manager, self._mem_stats_collector, **placement_kwargs)
|
||||||
self._compute_list: List[Tuple[Chunk, ...]] = []
|
self._compute_list: List[Tuple[Chunk, ...]] = []
|
||||||
self._compute_idx: int = -1
|
self._compute_idx: int = -1
|
||||||
|
self._async_works: Dict[Chunk, dist.work] = {}
|
||||||
|
|
||||||
self._h2d_volume = 0
|
self._h2d_volume = 0
|
||||||
self._d2h_volume = 0
|
self._d2h_volume = 0
|
||||||
|
@ -98,11 +100,13 @@ class GeminiManager:
|
||||||
# find stateful tensor in state COMPUTE
|
# find stateful tensor in state COMPUTE
|
||||||
start = time()
|
start = time()
|
||||||
self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
|
self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
|
||||||
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
|
cuda_demand, can_evict_chunks = self._get_layout_info(self._compute_idx, self._warmup, chunks)
|
||||||
|
# don't evict chunks that are asynchronously fetched
|
||||||
|
can_evict_chunks = [chunk for chunk in can_evict_chunks if chunk not in self._async_works]
|
||||||
self._layout_time += time() - start
|
self._layout_time += time() - start
|
||||||
|
|
||||||
vol, evict_time = self._placement_policy.evict_tensors(
|
vol, evict_time = self._placement_policy.evict_tensors(
|
||||||
can_evict_chunks=hold_cuda_tensor_list,
|
can_evict_chunks=can_evict_chunks,
|
||||||
cuda_demand=cuda_demand,
|
cuda_demand=cuda_demand,
|
||||||
warmup=self._warmup,
|
warmup=self._warmup,
|
||||||
compute_list=self._compute_list,
|
compute_list=self._compute_list,
|
||||||
|
@ -114,6 +118,21 @@ class GeminiManager:
|
||||||
# move COMPUTE tensors to CUDA
|
# move COMPUTE tensors to CUDA
|
||||||
self._h2d_volume += cuda_demand
|
self._h2d_volume += cuda_demand
|
||||||
|
|
||||||
|
def wait_chunks(self, chunks: Iterable[Chunk]) -> Tuple[Chunk]:
|
||||||
|
non_prefetched_chunks = []
|
||||||
|
for chunk in chunks:
|
||||||
|
if chunk in self._async_works:
|
||||||
|
self._async_works[chunk].wait()
|
||||||
|
del self._async_works[chunk]
|
||||||
|
else:
|
||||||
|
non_prefetched_chunks.append(chunk)
|
||||||
|
return tuple(non_prefetched_chunks)
|
||||||
|
|
||||||
|
def add_work(self, chunk: Chunk, work: dist.Work):
|
||||||
|
assert work is not None
|
||||||
|
assert chunk not in self._async_works
|
||||||
|
self._async_works[chunk] = work
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=None)
|
@functools.lru_cache(maxsize=None)
|
||||||
def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]):
|
def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]):
|
||||||
start = time()
|
start = time()
|
||||||
|
|
|
@ -18,10 +18,17 @@ class PlacementPolicy(ABC):
|
||||||
need_mem_stats: bool = False
|
need_mem_stats: bool = False
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, **kwargs
|
self,
|
||||||
|
gemini_manager: "GeminiManager",
|
||||||
|
chunk_manager: ChunkManager,
|
||||||
|
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||||
|
max_prefetch: int = 0,
|
||||||
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
self.gemini_manager = gemini_manager
|
||||||
self.chunk_manager = chunk_manager
|
self.chunk_manager = chunk_manager
|
||||||
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
|
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
|
||||||
|
self.max_prefetch = max_prefetch
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
|
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
|
||||||
|
@ -34,21 +41,30 @@ class PlacementPolicy(ABC):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
|
def get_prefetch_chunks(self) -> List[Chunk]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
rank = int(os.environ["RANK"])
|
||||||
|
|
||||||
|
|
||||||
class StaticPlacementPolicy(PlacementPolicy):
|
class StaticPlacementPolicy(PlacementPolicy):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
gemini_manager: "GeminiManager",
|
||||||
chunk_manager: ChunkManager,
|
chunk_manager: ChunkManager,
|
||||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||||
|
max_prefetch: int = 0,
|
||||||
shard_param_frac: float = 1.0,
|
shard_param_frac: float = 1.0,
|
||||||
offload_optim_frac: float = 0.0,
|
offload_optim_frac: float = 0.0,
|
||||||
offload_param_frac: float = 0.0,
|
offload_param_frac: float = 0.0,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
super().__init__(
|
||||||
|
gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
|
||||||
|
)
|
||||||
if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
|
if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
|
||||||
warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
|
warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
|
||||||
offload_param_frac = 0.0
|
offload_param_frac = 0.0
|
||||||
|
@ -99,15 +115,17 @@ class StaticPlacementPolicy(PlacementPolicy):
|
||||||
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
|
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
|
||||||
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
|
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
|
||||||
|
|
||||||
def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
|
def get_prefetch_chunks(self) -> List[Chunk]:
|
||||||
|
if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list
|
||||||
|
return []
|
||||||
prefetch = []
|
prefetch = []
|
||||||
for i in range(self.chunk_manager.compute_idx + 1, len(self.chunk_manager.compute_list)):
|
for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
|
||||||
for chunk in self.chunk_manager.compute_list[i]:
|
for chunk in self.gemini_manager.compute_list[i]:
|
||||||
if len(prefetch) >= max_prefetch:
|
if len(prefetch) >= self.max_prefetch:
|
||||||
break
|
break
|
||||||
if chunk not in prefetch:
|
if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
|
||||||
prefetch.append(chunk)
|
prefetch.append(chunk)
|
||||||
if len(prefetch) >= max_prefetch:
|
if len(prefetch) >= self.max_prefetch:
|
||||||
break
|
break
|
||||||
return prefetch
|
return prefetch
|
||||||
|
|
||||||
|
@ -117,13 +135,17 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
gemini_manager: "GeminiManager",
|
||||||
chunk_manager: ChunkManager,
|
chunk_manager: ChunkManager,
|
||||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||||
|
max_prefetch: int = 0,
|
||||||
warmup_non_model_data_ratio: float = 0.8,
|
warmup_non_model_data_ratio: float = 0.8,
|
||||||
steady_cuda_cap_ratio: float = 0.9,
|
steady_cuda_cap_ratio: float = 0.9,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
super().__init__(
|
||||||
|
gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
|
||||||
|
)
|
||||||
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
|
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
|
||||||
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
|
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
|
||||||
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
|
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
|
||||||
|
|
|
@ -30,8 +30,9 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
|
||||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||||
schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
|
schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
|
||||||
on_trace_ready=tensorboard_trace_handler(save_dir),
|
on_trace_ready=tensorboard_trace_handler(save_dir),
|
||||||
record_shapes=True,
|
# record_shapes=True,
|
||||||
profile_memory=True,
|
# profile_memory=True,
|
||||||
|
with_stack=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return nullcontext(DummyProfiler())
|
return nullcontext(DummyProfiler())
|
||||||
|
|
|
@ -129,7 +129,7 @@ def main():
|
||||||
WARMUP_STEPS = 1
|
WARMUP_STEPS = 1
|
||||||
assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
|
assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
|
||||||
assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
|
assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
|
||||||
PROF_FLAG = False # The flag of profiling, False by default
|
PROF_FLAG = True # The flag of profiling, False by default
|
||||||
|
|
||||||
disable_existing_loggers()
|
disable_existing_loggers()
|
||||||
colossalai.launch_from_torch()
|
colossalai.launch_from_torch()
|
||||||
|
@ -166,7 +166,7 @@ def main():
|
||||||
stage=zero_stage, reduce_bucket_size_in_m=12, overlap_communication=True, verbose=True
|
stage=zero_stage, reduce_bucket_size_in_m=12, overlap_communication=True, verbose=True
|
||||||
)
|
)
|
||||||
elif args.distplan == "CAI_Gemini":
|
elif args.distplan == "CAI_Gemini":
|
||||||
plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd)
|
plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd, max_prefetch=1)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
|
|
||||||
|
@ -248,7 +248,7 @@ def main():
|
||||||
prof.step()
|
prof.step()
|
||||||
|
|
||||||
tflops_list.sort()
|
tflops_list.sort()
|
||||||
median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
|
median_index = min(((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS, len(tflops_list) - 1)
|
||||||
logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
|
logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue