mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/cli/benchmark/utils.py code style (#4254)
parent dee1c96344
commit 85774f0c1f
@@ -1,10 +1,11 @@
 import math
 import time
+from typing import Callable, Dict, List, Tuple
+
 import torch
 
-from colossalai.context import ParallelMode, Config
+from colossalai.context import Config, ParallelMode
 from colossalai.utils import MultiTimer
-from typing import List, Dict, Tuple, Callable
 
 
 def get_time_stamp() -> int:
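Only the signature of get_time_stamp() is visible in this hunk. For orientation, here is a minimal sketch of what a benchmark timestamp helper of this shape typically does; the CUDA synchronize and the millisecond unit are assumptions, not taken from the commit.

    import time

    import torch


    def get_time_stamp() -> int:
        # Drain pending CUDA kernels first so the stamp reflects completed
        # work (assumed; common practice in GPU benchmarking helpers).
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        # Millisecond wall-clock stamp; only the -> int signature is visible
        # in the diff, so the unit is an assumption.
        return int(round(time.time() * 1000))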
@@ -25,8 +26,8 @@ def get_memory_states() -> Tuple[float]:
     Return the memory statistics.
 
     Returns:
-        max_allocated (float): the allocated CUDA memory
-        max_cached (float): the cached CUDA memory
+        max_allocated (float): the allocated CUDA memory
+        max_cached (float): the cached CUDA memory
     """
 
     max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
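The hunk shows only the first line of the function body. Here is a self-contained sketch consistent with the docstring above; the max_cached line and the return statement are assumptions (recent PyTorch names the cached peak torch.cuda.max_memory_reserved(), older releases torch.cuda.max_memory_cached()).

    from typing import Tuple

    import torch


    def get_memory_states() -> Tuple[float, float]:
        # Peak allocated CUDA memory in GiB; this line appears verbatim
        # in the hunk above.
        max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
        # Peak cached (reserved) CUDA memory in GiB; assumed counterpart
        # for the max_cached entry in the docstring.
        max_cached = torch.cuda.max_memory_reserved() / (1024**3)
        return max_allocated, max_cached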
@@ -101,7 +102,7 @@ def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int,
         profile_steps (int): the number of steps for profiling
         data_func (Callable): a function to generate random data
         timer (colossalai.utils.Multitimer): a timer instance for time recording
-
+
     Returns:
         fwd_time (float): the average forward time taken by forward pass in second
         bwd_time (float): the average backward time taken by forward pass in second

(remainder of the diff did not load)
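The docstring describes a standard warmup-then-measure loop (its "backward time taken by forward pass" reads as a typo for "backward pass"). Below is a minimal self-contained sketch of that technique: it substitutes time.perf_counter() for the repo's MultiTimer, drops the timer argument, and assumes the model's output supports .sum(); none of this is the commit's actual implementation.

    import time
    from typing import Callable, Tuple

    import torch


    def profile_model(model: torch.nn.Module, warmup_steps: int,
                      profile_steps: int, data_func: Callable) -> Tuple[float, float]:

        def sync() -> None:
            # GPU kernels run asynchronously; synchronize before reading the clock.
            if torch.cuda.is_available():
                torch.cuda.synchronize()

        # Warmup: let CUDA context creation and kernel autotuning settle.
        for _ in range(warmup_steps):
            model.zero_grad()
            model(data_func()).sum().backward()

        fwd_total = bwd_total = 0.0
        for _ in range(profile_steps):
            data = data_func()
            model.zero_grad()

            sync()
            start = time.perf_counter()
            out = model(data)
            sync()
            fwd_total += time.perf_counter() - start

            loss = out.sum()
            sync()
            start = time.perf_counter()
            loss.backward()
            sync()
            bwd_total += time.perf_counter() - start

        # Average seconds per step, matching the fwd_time/bwd_time entries above.
        return fwd_total / profile_steps, bwd_total / profile_steps

A hypothetical call, with data_func supplying a fresh random batch each step:

    fwd_time, bwd_time = profile_model(model, warmup_steps=5, profile_steps=20,
                                       data_func=lambda: torch.randn(8, 1024, device='cuda'))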