# Mirror of https://github.com/hpcaitech/ColossalAI (Python, 31 lines, 928 B)
|
import logging
|
||
|
import os
|
||
|
|
||
|
import torch.distributed as dist
|
||
|
|
||
|
# Configure the root logger at import time: timestamped records at INFO and
# above. Per the stdlib docs, basicConfig does nothing if the root logger
# already has handlers.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
class Logger:
    """Logger that forwards to ``logging`` and appends plain text to a file.

    When ``cuda`` is True, :meth:`info` only emits output on rank 0 of the
    ``torch.distributed`` default process group, so multi-process runs do not
    duplicate every line.
    """

    def __init__(self, log_path, cuda=False, debug=False):
        """Create the logger.

        Args:
            log_path: Path of the text file that :meth:`info` appends to.
            cuda: If True, restrict :meth:`info` output to distributed rank 0.
            debug: Stored as ``self.debug``; not read by this class.
        """
        self.logger = logging.getLogger(__name__)
        self.cuda = cuda
        self.log_path = log_path
        self.debug = debug

    def _should_emit(self):
        """Return True if this process should produce :meth:`info` output."""
        if not self.cuda:
            return True
        # dist.get_rank() raises when no default process group exists; a
        # process that never initialised torch.distributed is treated as the
        # sole (hence main) process instead of crashing.
        if not (dist.is_available() and dist.is_initialized()):
            return True
        return dist.get_rank() == 0

    @staticmethod
    def _render(message, args):
        """Render ``message`` with %-style ``args`` for the log file.

        Mirrors ``logging``'s formatting for the common cases: ``str()`` of
        the message, %-interpolation only when args are given, and a single
        non-empty dict argument used as the mapping for ``%(key)s`` templates.
        """
        text = str(message)
        if not args:
            return text
        if len(args) == 1 and isinstance(args[0], dict) and args[0]:
            args = args[0]
        try:
            return text % args
        except (TypeError, ValueError, KeyError):
            # Malformed template/args: keep the raw text rather than crash,
            # since logging itself only reports such errors.
            return text

    def info(self, message, log_=True, print_=True, *args, **kwargs):
        """Log ``message`` at INFO level and/or append it to ``log_path``.

        Does nothing on non-zero ranks when ``cuda`` is True.

        Args:
            message: Message or %-style template, as accepted by ``logging``.
            log_: If True, append the rendered message and a newline to
                ``self.log_path`` (UTF-8).
            print_: If True, forward to ``self.logger.info``.
            *args: %-style format arguments. They follow ``log_`` and
                ``print_`` positionally, so callers must pass both flags
                explicitly before any format arguments.
            **kwargs: Passed through to ``logging.Logger.info`` only.
        """
        if not self._should_emit():
            return
        if print_:
            self.logger.info(message, *args, **kwargs)
        if log_:
            with open(self.log_path, 'a', encoding='utf-8') as f_log:
                f_log.write(self._render(message, args) + '\n')

    def error(self, message, *args, **kwargs):
        """Log ``message`` at ERROR level via ``self.logger``.

        Unlike :meth:`info`, this does not check ``cuda``/rank and does not
        write to ``log_path``.
        """
        self.logger.error(message, *args, **kwargs)
|