[NFC] polish colossalai/cli/benchmark/utils.py code style (#4254)

pull/4338/head
ocd_with_naming 2023-07-18 10:54:27 +08:00 committed by binmakeswell
parent dee1c96344
commit 85774f0c1f
1 changed file with 6 additions and 5 deletions

View File

@ -1,10 +1,11 @@
import math
import time
from typing import Callable, Dict, List, Tuple
import torch
from colossalai.context import Config, ParallelMode
from colossalai.utils import MultiTimer
from colossalai.context import ParallelMode, Config
from typing import List, Dict, Tuple, Callable
def get_time_stamp() -> int:
@ -25,8 +26,8 @@ def get_memory_states() -> Tuple[float]:
Return the memory statistics.
Returns:
max_allocated (float): the allocated CUDA memory
max_cached (float): the cached CUDA memory
max_allocated (float): the allocated CUDA memory
max_cached (float): the cached CUDA memory
"""
max_allocated = torch.cuda.max_memory_allocated() / (1024**3)
@ -101,7 +102,7 @@ def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int,
profile_steps (int): the number of steps for profiling
data_func (Callable): a function to generate random data
timer (colossalai.utils.MultiTimer): a timer instance for time recording
Returns:
fwd_time (float): the average time taken by the forward pass, in seconds
bwd_time (float): the average time taken by the backward pass, in seconds