import logging
import os

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# aitest data
MODEL_SAVE_PATH = 'data/aitest/models'
ADV_EXAMPLES_SAVE_PATH = 'data/aitest/adv_examples'
DATASET_SAVE_PATH = 'data/aitest/datasets'
FILES_SAVE_PATH = 'data/aitest/files'

# text_attack configuration
USE_ENCODING_PATH = "data/aitest/models/tfhub_modules/063d866c06683311b44b4992fd46003be952409c"
TEXTATTACK_CACHE_DIR = 'data/aitest/models/textattack'

# text robust analysis
GPT2_PATH = 'data/aitest/models/gpt-2'
NLTK_PATH = 'data/aitest/files/nltk_data'


class LogHelper:
    """Manage a named logger with a console handler and a detachable file handler.

    Typical use::

        helper = LogHelper(path).build_new_log()
        helper.insert_log("message")
        helper.finish_log()
    """

    LOG_STRING = "\033[34;1mAItest\033[0m"
    # Name assigned to the console handler so that it is attached to a given
    # (process-wide, shared) logger only once, however many helpers are built.
    _CONSOLE_HANDLER_NAME = "aitest_console"

    def __init__(self, log_path, root_log_name=None):
        """Configure the logger and prepare (but do not attach) the file handler.

        Args:
            log_path: Path of the log file. It is opened in append mode immediately.
            root_log_name: Optional parent logger name. When given, the logger is
                "<root_log_name>.Interpretability_Image"; otherwise it is
                "aitest_Interpretability_Image".
        """
        self.log_path = log_path
        if root_log_name:
            self.log_name = f"{root_log_name}.Interpretability_Image"
        else:
            self.log_name = 'aitest_Interpretability_Image'
        self.logger = logging.getLogger(self.log_name)

        # insert_log() emits at INFO. A logger left at NOTSET inherits the root's
        # default WARNING level and would drop those records before any handler
        # sees them, so lower the level only when INFO would otherwise be lost.
        if self.logger.getEffectiveLevel() > logging.INFO:
            self.logger.setLevel(logging.INFO)

        # getLogger() returns the same object for the same name. Attach the
        # console handler only once to avoid duplicated console output.
        if not any(h.name == self._CONSOLE_HANDLER_NAME for h in self.logger.handlers):
            formatter = logging.Formatter(f"{LogHelper.LOG_STRING}: %(message)s")
            stream_handler = logging.StreamHandler()  # no argument -> sys.stderr
            stream_handler.set_name(self._CONSOLE_HANDLER_NAME)
            stream_handler.setFormatter(formatter)
            self.logger.addHandler(stream_handler)

        log_formatter = logging.Formatter("%(asctime)s : %(message)s")
        # Only records from this logger (or its children) reach the file.
        log_filter = logging.Filter(name=self.log_name)
        # UTF-8 so that get_log(), which reads as UTF-8, can decode the file.
        log_handler = logging.FileHandler(self.log_path, encoding='utf-8')
        log_handler.setLevel(level=logging.INFO)
        log_handler.setFormatter(log_formatter)
        log_handler.addFilter(log_filter)
        self.fh = log_handler

    def build_new_log(self):
        """Attach the file handler to the logger and return ``self`` for chaining."""
        self.logger.addHandler(self.fh)
        return self

    def insert_log(self, content=None):
        """Log ``content`` at INFO level; falsy content (None, "") is ignored."""
        if content:
            self.logger.info(content)

    def finish_log(self):
        """Detach the file handler and close it so its file descriptor is released."""
        self.logger.removeHandler(self.fh)
        self.fh.close()


def get_log(log_path):
    """Return the full text of ``log_path`` and then empty the file.

    Returns:
        The file contents as a string, or None if the path does not exist.
    """
    if os.path.exists(log_path):
        with open(log_path, 'r+', encoding='utf-8') as f:
            texts = f.read()
            f.truncate(0)
        return texts
    return None


def get_size(path, show_bytes=True):
    """Return the size of a file, or the total size of all files under a directory.

    Args:
        path: File or directory path. A path that is neither a file nor a
            directory counts as size 0.
        show_bytes: If True (default), return the byte count as an int.
            Otherwise return a human-readable string such as ``"2.0 KB"``.
    """
    size = 0
    if os.path.isfile(path):
        size = os.path.getsize(path)
    elif os.path.isdir(path):
        for root, _dirs, files in os.walk(path):
            for file in files:
                try:
                    size += os.path.getsize(os.path.join(root, file))
                except OSError:
                    # Broken symlink or file removed mid-walk: skip it instead of aborting.
                    continue
    if show_bytes:
        return int(size)
    for unit in ['bytes', 'KB', 'MB', 'GB', 'TB']:
        if size < 1024.0:
            return "%3.1f %s" % (size, unit)
        size /= 1024.0
    # The value has now been divided five times by 1024, so it is in PB.
    return "%3.1f PB" % size