mirror of https://github.com/InternLM/InternLM
				
				
				
			
		
			
				
	
	
		
			112 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			112 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
	
#!/usr/bin/env python
 | 
						|
# -*- encoding: utf-8 -*-
 | 
						|
 | 
						|
import time
 | 
						|
 | 
						|
import torch
 | 
						|
 | 
						|
 | 
						|
class _Timer:
    """Accumulating wall-clock timer that synchronizes CUDA around each measurement.

    The timer adds up the durations of successive ``start()``/``stop()`` pairs
    until it is reset.

    Attributes:
        name_: Label given at construction time.
        elapsed_: Seconds accumulated over all completed start/stop intervals.
        started_: True while an interval is in progress.
        start_time: ``time.time()`` reading taken at the most recent start
            (or at construction, before the first start).
    """

    def __init__(self, name):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    @staticmethod
    def _synchronize():
        """Call ``torch.cuda.synchronize()`` only when CUDA is usable.

        An unconditional synchronize raises on hosts without CUDA, which would
        make the timer unusable there; on such hosts there is nothing to wait
        for, so the call is skipped.
        """
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def start(self):
        """Start the timer.

        Raises:
            AssertionError: If the timer is already running.
        """
        assert not self.started_, "timer has already been started"
        # Synchronize before reading the clock so the interval starts after
        # any CUDA work that was already queued.
        self._synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """Stop the timer and add the finished interval to ``elapsed_``.

        Raises:
            AssertionError: If the timer is not running.
        """
        assert self.started_, "timer is not started"
        # Synchronize first so the interval covers the CUDA work issued
        # while the timer was running.
        self._synchronize()
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False

    def reset(self):
        """Reset timer: clear the accumulated time and mark it as not running."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        """Return the accumulated time in seconds.

        A running timer is stopped so the in-progress interval is counted,
        and is started again before returning.

        Args:
            reset: When True (default), clear the accumulated time after
                reading it.

        Returns:
            The accumulated seconds, including any in-progress interval.
        """
        was_running = self.started_
        # If the timing is in progress, end it first so it is included.
        if was_running:
            self.stop()
        total = self.elapsed_
        if reset:
            self.reset()
        # If timing was in progress, resume it.
        if was_running:
            self.start()
        return total
 | 
						|
 | 
						|
 | 
						|
class Timers:
    """Group of timers, looked up (and created on demand) by name."""

    def __init__(self):
        # Maps timer name -> timer object.
        self.timers = {}

    def __call__(self, name):
        """Return the timer registered under ``name``, creating it if needed."""
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timers to a tensorboard writer.

        Args:
            names: Timer names to report; names with no timer are skipped.
            writer: Object exposing ``add_scalar(tag, value, step)``.
            iteration: Step value passed to ``add_scalar``.
            normalizer: Positive divisor applied to each elapsed time (seconds).
            reset: Forwarded to each timer's ``elapsed``.
        """
        # currently when using add_scalars,
        # torch.utils.add_scalars makes each timer its own run, which
        # pollutes the runs list, so we just add each as a scalar
        assert normalizer > 0.0
        for name in names:
            if name in self.timers:
                value = self.timers[name].elapsed(reset=reset) / normalizer
                writer.add_scalar(f"time/{name}-time", value, iteration)

    def _format_elapsed(self, names, normalizer, reset):
        """Build the ``time (ms) | name: value`` line shared by ``log`` and ``debug``.

        Returns None when none of ``names`` has a registered timer.
        """
        assert normalizer > 0.0
        parts = [
            " | {}: {:.2f}".format(name, self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer)
            for name in names
            if name in self.timers
        ]
        if not parts:
            return None
        return "time (ms)" + "".join(parts)

    def log(self, names, logger, normalizer=1.0, reset=True):
        """Log a group of timers at info level.

        Args:
            names: Timer names to report; names with no timer are skipped.
            logger: Object exposing ``info(message)``.
            normalizer: Positive divisor applied to each elapsed time.
            reset: Forwarded to each timer's ``elapsed``.

        Returns:
            The logged string (times in milliseconds), or None when nothing
            matched and nothing was logged.
        """
        string = self._format_elapsed(names, normalizer, reset)
        if string is None:
            return None

        logger.info(string)
        return string

    def debug(self, names, logger, normalizer=1.0, reset=True):
        """Log a group of timers at debug level.

        Args:
            names: Timer names to report; names with no timer are skipped.
            logger: Object exposing ``debug(message)``.
            normalizer: Positive divisor applied to each elapsed time.
            reset: Forwarded to each timer's ``elapsed``.

        Returns:
            The logged string (times in milliseconds), or None when nothing
            matched and nothing was logged.
        """
        string = self._format_elapsed(names, normalizer, reset)
        if string is None:
            return None

        logger.debug(string)
        return string

    def reset(self):
        """Reset every registered timer."""
        for timer in self.timers.values():
            timer.reset()
 | 
						|
 | 
						|
 | 
						|
# Module-level Timers group, created once when this module is imported.
megatron_timer = Timers()
 |