image_interprebility/utils.py

89 lines
2.6 KiB
Python
Raw Normal View History

2023-06-05 15:11:03 +08:00
import logging
import os
import torch
# Torch device used by this package: CUDA when a GPU is available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# aitest data locations. These are relative paths, so they resolve against the
# current working directory of the running process.
MODEL_SAVE_PATH = 'data/aitest/models'
ADV_EXAMPLES_SAVE_PATH = 'data/aitest/adv_examples'
DATASET_SAVE_PATH = 'data/aitest/datasets'
FILES_SAVE_PATH = 'data/aitest/files'

# text_attack configuration
USE_ENCODING_PATH = "data/aitest/models/tfhub_modules/063d866c06683311b44b4992fd46003be952409c"
TEXTATTACK_CACHE_DIR = 'data/aitest/models/textattack'

# text robust analysis
GPT2_PATH = 'data/aitest/models/gpt-2'
NLTK_PATH = 'data/aitest/files/nltk_data'
class LogHelper:
    """Send interpretability log messages to the console and to a log file.

    A console handler is attached once per logger when the helper is built.
    The file handler is created up front but only attached between
    ``build_new_log`` and ``finish_log``, so only messages logged in that
    window are appended to ``log_path``.
    """

    # Coloured "AItest" prefix (ANSI bold blue) used by the console formatter.
    LOG_STRING = "\033[34;1mAItest\033[0m"

    def __init__(self, log_path, root_log_name=None):
        """Create the named logger and its console and file handlers.

        Args:
            log_path: File that receives INFO-level records while a log
                session is active. The file is opened here in append mode
                and written as UTF-8.
            root_log_name: Optional parent logger name. When given, the logger
                is named ``"<root_log_name>.Interpretability_Image"``, so its
                records also propagate to that parent's handlers.
        """
        self.log_path = log_path
        if root_log_name:
            self.log_name = f"{root_log_name}.Interpretability_Image"
        else:
            self.log_name = 'aitest_Interpretability_Image'
        self.logger = logging.getLogger(self.log_name)

        # A fresh logger is NOTSET and inherits root's WARNING level, which
        # would drop the INFO records emitted by insert_log before any handler
        # sees them. Only supply a level when nobody has configured one.
        if self.logger.level == logging.NOTSET:
            self.logger.setLevel(logging.INFO)

        # getLogger returns one shared object per name. Without this guard,
        # every LogHelper instance would add another console handler and
        # duplicate each console line.
        console_name = f"{self.log_name}.console"
        if not any(h.get_name() == console_name for h in self.logger.handlers):
            formatter = logging.Formatter(f"{LogHelper.LOG_STRING}: %(message)s")
            # No stream argument means sys.stderr. A non-stream value such as
            # '' would make every emit fail with AttributeError.
            stream_handler = logging.StreamHandler()
            stream_handler.set_name(console_name)
            stream_handler.setFormatter(formatter)
            self.logger.addHandler(stream_handler)

        log_formatter = logging.Formatter("%(asctime)s : %(message)s")
        # Only accept records from this logger (or its children) in the file.
        log_filter = logging.Filter(name=self.log_name)
        # UTF-8 so the file can be read back with encoding='utf-8' regardless
        # of the platform's locale encoding.
        log_handler = logging.FileHandler(self.log_path, encoding='utf-8')
        log_handler.setLevel(level=logging.INFO)
        log_handler.setFormatter(log_formatter)
        log_handler.addFilter(log_filter)
        self.fh = log_handler

    def build_new_log(self):
        """Attach the file handler so later records are appended to ``log_path``.

        Returns:
            ``self``, so calls can be chained: ``LogHelper(p).build_new_log()``.
        """
        self.logger.addHandler(self.fh)
        return self

    def insert_log(self, content=None):
        """Log ``content`` at INFO level. Falsy content (None, '') is ignored."""
        if content:
            self.logger.info(content)

    def finish_log(self):
        """Detach the file handler and close its file.

        A closed FileHandler (mode 'a') reopens its file on the next emit, so
        ``build_new_log`` can still be called again afterwards.
        """
        self.logger.removeHandler(self.fh)
        # Release the open file handle instead of leaking one per instance.
        self.fh.close()
def get_log(log_path):
    """Read and clear a UTF-8 text file.

    Args:
        log_path: Path of the file to drain.

    Returns:
        The file's full contents as a string, after which the file is
        truncated to zero length; or None if ``log_path`` does not exist.
    """
    if not os.path.exists(log_path):
        return None
    with open(log_path, 'r+', encoding='utf-8') as log_file:
        contents = log_file.read()
        log_file.truncate(0)
    return contents
def get_size(path, show_bytes=True):
    """Return the size of a file, or the total size of the files under a directory.

    Args:
        path: File or directory path. A path that is neither (e.g. missing)
            counts as size 0.
        show_bytes: When True (default), return the raw byte count as an int.
            Otherwise return a human-readable string such as ``"2.0 KB"``.

    Returns:
        An int byte count, or a ``"%3.1f <unit>"`` string whose unit is one of
        bytes/KB/MB/GB/TB/PB (1024-based steps).
    """
    size = 0
    if os.path.isfile(path):
        size = os.path.getsize(path)
    elif os.path.isdir(path):
        for root, _dirs, files in os.walk(path):
            for name in files:
                try:
                    size += os.path.getsize(os.path.join(root, name))
                except FileNotFoundError:
                    # Dangling symlink, or a file removed during the walk.
                    # os.walk already ignores errors by default, so stay
                    # best-effort here too instead of aborting the total.
                    continue
    if show_bytes:
        return int(size)
    units = ['bytes', 'KB', 'MB', 'GB', 'TB', 'PB']
    for unit in units[:-1]:
        if size < 1024.0:
            return "%3.1f %s" % (size, unit)
        size /= 1024.0
    # >= 1024 TB: report in the largest unit instead of a bare unit-less float.
    return "%3.1f %s" % (size, units[-1])