89 lines
2.6 KiB
Python
89 lines
2.6 KiB
Python
import logging
|
|
import os
|
|
import torch
|
|
|
|
# Torch device used by the project: CUDA when PyTorch reports it is available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


# ---- aitest data -------------------------------------------------------------
# Relative paths (no leading '/'), so they resolve against the working directory.
MODEL_SAVE_PATH: str = 'data/aitest/models'
ADV_EXAMPLES_SAVE_PATH: str = 'data/aitest/adv_examples'
DATASET_SAVE_PATH: str = 'data/aitest/datasets'
FILES_SAVE_PATH: str = 'data/aitest/files'


# ---- text_attack configuration -----------------------------------------------
USE_ENCODING_PATH: str = "data/aitest/models/tfhub_modules/063d866c06683311b44b4992fd46003be952409c"
TEXTATTACK_CACHE_DIR: str = 'data/aitest/models/textattack'


# ---- text robust analysis ----------------------------------------------------
GPT2_PATH: str = 'data/aitest/models/gpt-2'
NLTK_PATH: str = 'data/aitest/files/nltk_data'
|
|
|
|
|
|
class LogHelper:
    """Attach a console handler and a detachable UTF-8 file handler to a named logger.

    Typical use::

        helper = LogHelper(path, root_log_name).build_new_log()
        helper.insert_log("message")
        helper.finish_log()
    """

    # Prefix for console output; the ANSI escapes render "AItest" in bold blue.
    LOG_STRING = "\033[34;1mAItest\033[0m"
    # Name given to the console handler so later instances on the same logger
    # can detect that it is already installed.
    CONSOLE_HANDLER_NAME = "aitest_console"

    def __init__(self, log_path, root_log_name=None):
        """Configure the logger and prepare (but do not attach) the file handler.

        Args:
            log_path: File that receives INFO-and-above records while the file
                handler is attached (see ``build_new_log``). The file is opened
                (append mode) as soon as this constructor runs.
            root_log_name: Optional parent logger name. When truthy, the logger is
                ``"<root_log_name>.Interpretability_Image"``; otherwise it is
                ``"aitest_Interpretability_Image"``.
        """
        self.log_path = log_path

        if root_log_name:
            self.log_name = f"{root_log_name}.Interpretability_Image"
        else:
            self.log_name = 'aitest_Interpretability_Image'
        self.logger = logging.getLogger(self.log_name)

        # A fresh logger is NOTSET and inherits WARNING from the root logger,
        # which would drop every INFO record emitted by insert_log before it
        # reaches the INFO-level file handler. Only adjust the level when nobody
        # has configured this logger explicitly.
        if self.logger.level == logging.NOTSET:
            self.logger.setLevel(logging.INFO)

        # logging.getLogger returns the same object for the same name, so add the
        # console handler only once per logger to avoid duplicated console lines.
        has_console = any(
            handler.name == LogHelper.CONSOLE_HANDLER_NAME
            for handler in self.logger.handlers
        )
        if not has_console:
            formatter = logging.Formatter(f"{LogHelper.LOG_STRING}: %(message)s")
            # No stream argument: output goes to sys.stderr. (A str is not a
            # valid stream - it has no write() method.)
            stream_handler = logging.StreamHandler()
            stream_handler.set_name(LogHelper.CONSOLE_HANDLER_NAME)
            stream_handler.setFormatter(formatter)
            self.logger.addHandler(stream_handler)

        log_formatter = logging.Formatter("%(asctime)s : %(message)s")
        # Only accept records from this logger and its descendants.
        log_filter = logging.Filter(name=self.log_name)
        # Write UTF-8 explicitly so readers that decode the file as UTF-8 see the
        # same text regardless of the platform's locale encoding.
        log_handler = logging.FileHandler(self.log_path, encoding='utf-8')
        log_handler.setLevel(level=logging.INFO)
        log_handler.setFormatter(log_formatter)
        log_handler.addFilter(log_filter)
        self.fh = log_handler

    def build_new_log(self):
        """Attach the file handler to the logger and return ``self`` for chaining."""
        self.logger.addHandler(self.fh)
        return self

    def insert_log(self, content=None):
        """Log *content* at INFO level; falsy content (None, "", 0, ...) is ignored."""
        if content:
            self.logger.info(content)

    def finish_log(self):
        """Detach the file handler from the logger and close its file."""
        self.logger.removeHandler(self.fh)
        # Release the file descriptor; removing the handler alone leaves it open.
        self.fh.close()
|
|
|
|
|
|
def get_log(log_path):
    """Read a log file's whole text (UTF-8) and then empty the file.

    Args:
        log_path: Path of the log file.

    Returns:
        The file contents as a string, or ``None`` when ``os.path.exists``
        reports that the path does not exist.
    """
    if not os.path.exists(log_path):
        return None

    # 'r+' permits reading and then truncating through the same handle.
    with open(log_path, 'r+', encoding='utf-8') as handle:
        contents = handle.read()
        handle.truncate(0)
        return contents
|
|
|
|
|
|
def get_size(path, show_bytes=True):
    """Return the size of a file, or the total size of all files under a directory.

    Args:
        path: File or directory path. A path that is neither (e.g. missing)
            counts as size 0.
        show_bytes: When true (default), return the byte count as an ``int``;
            otherwise return a human-readable string such as ``"1.5 KB"``.

    Returns:
        An ``int`` byte count, or a ``str`` with one decimal place and a
        1024-based unit from bytes / KB / MB / GB / TB / PB.
    """
    size = 0
    if os.path.isfile(path):
        size = os.path.getsize(path)
    elif os.path.isdir(path):
        for root, _dirs, files in os.walk(path):
            for name in files:
                try:
                    size += os.path.getsize(os.path.join(root, name))
                except FileNotFoundError:
                    # Dangling symlink, or a file removed while walking: skip it
                    # instead of aborting the whole computation.
                    continue

    if show_bytes:
        return int(size)

    for unit in ('bytes', 'KB', 'MB', 'GB', 'TB'):
        if size < 1024.0:
            return "%3.1f %s" % (size, unit)
        size /= 1024.0
    # >= 1024 TB: keep the return type a labelled string instead of a bare float.
    return "%3.1f %s" % (size, 'PB')
|
|
|
|
|