import argparse
import json
import urllib.request
from importlib import import_module

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import models, transforms

from pytorch_grad_cam import (GradCAM,
                              HiResCAM,
                              ScoreCAM,
                              GradCAMPlusPlus,
                              AblationCAM,
                              XGradCAM,
                              EigenCAM,
                              EigenGradCAM,
                              LayerCAM,
                              FullGrad,
                              GradCAMElementWise,
                              GuidedBackpropReLUModel)
from pytorch_grad_cam.utils.find_layers import find_layer_types_recursive
from pytorch_grad_cam.utils.image import (show_cam_on_image,
                                          deprocess_image,
                                          preprocess_image)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

from utils import LogHelper

# Load labels
with open("imagenet_1000.json") as f:
    labels = json.load(f)
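# Assumed layout of imagenet_1000.json (the file itself is not shown here): a JSON object
# mapping the class index as a string to its human-readable name,
# e.g. {"0": "tench", "1": "goldfish", ...}, matching the labels[str(predicted_class)]
# lookup in ImageInterpretability.perform().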

# Fall back to CPU when no GPU is present instead of hard-coding 'cuda'.
available_device = 'cuda' if torch.cuda.is_available() else 'cpu'


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--use-cuda', action='store_true', default=True,
                        help='Use NVIDIA GPU acceleration')
    parser.add_argument('--aug_smooth', action='store_true',
                        help='Apply test time augmentation to smooth the CAM')
    parser.add_argument(
        '--eigen_smooth',
        action='store_true',
        help='Reduce noise by taking the first principal component '
             'of cam_weights*activations')
    args = parser.parse_args()

    args.use_cuda = args.use_cuda and torch.cuda.is_available()
    if args.use_cuda:
        print('Using GPU for acceleration')
    else:
        print('Using CPU for computation')

    return args


class ImageInterpretability:

    def __init__(self):
        super().__init__()

    def perform(self, image_path: str, method: str, model_info: dict, output_path: str, log_path=None, **kwargs):
        '''
        :param image_path: (type=str, required=True) value= (path of the input image)

        :param method: (type=str, required=True) value=['gradcam', 'hirescam', 'scorecam', 'gradcam++',
            'ablationcam', 'xgradcam', 'eigencam', 'eigengradcam', 'layercam', 'gradcamelementwise', 'fullgrad']
            (interpretability method)

        :param model_name: (type=str, required=True) value=[resnet, vgg, densenet, mnasnet] (model name)

        :param model: (type=nn.Module, required=True) value=[resnet18, resnet50, vgg11, vgg13, vgg16, vgg19, densenet161, mnasnet1_0] (model)

        kwargs: optional parameters, only passed in for specific requirements
        {
            aug_smooth: apply test-time data augmentation to improve the CAM quality
                (type=bool) value=[True, False]

            eigen_smooth: compute the matrix product between the CAM (class activation map) weights
                and the activations, then take the first principal component of the result to reduce noise
                (type=bool) value=[True, False]
        }
        '''
        # aug_smooth, eigen_smooth,
        args = get_args()
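
        # Example arguments this method expects (assumed, matching get_model below):
        #     model_info = {'source': 'torchvision', 'model_name': 'resnet50'}
        #     kwargs     = {'aug_smooth': False, 'eigen_smooth': False}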
        methods = {
            "gradcam": GradCAM,
            "hirescam": HiResCAM,
            "scorecam": ScoreCAM,
            "gradcam++": GradCAMPlusPlus,
            "ablationcam": AblationCAM,
            "xgradcam": XGradCAM,
            "eigencam": EigenCAM,
            "eigengradcam": EigenGradCAM,
            "layercam": LayerCAM,
            "fullgrad": FullGrad,
            "gradcamelementwise": GradCAMElementWise,
        }

        # model = models.resnet50(pretrained=True)
        # model_class = getattr(models, model_info.model_name)
        # model = model_class(pretrained=True)
        model = self.get_model(info=model_info, device=available_device)
        # attack_log = LogHelper(log_path=log_path, root_log_name='aitest').build_new_log()

        # Choose the standard target layer for each supported backbone; fall back to
        # the last ReLU-type layers for any other architecture.
        model_name = model_info.get('model_name').lower()
        if 'resnet' in model_name:
            target_layers = [model.layer4]
        elif 'vgg' in model_name:
            target_layers = [model.features[-1]]
        elif 'densenet' in model_name:
            target_layers = [model.features[-1]]
        elif 'mnasnet' in model_name:
            target_layers = [model.layers[-1]]
        else:
            target_layers = find_layer_types_recursive(model, [torch.nn.ReLU])

        # cv2.imread returns BGR; reverse the channel axis to get an RGB image in [0, 1].
        rgb_img = cv2.imread(image_path, 1)[:, :, ::-1]
        rgb_img = np.float32(rgb_img) / 255
        input_tensor = preprocess_image(rgb_img,
                                        mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])

        # With targets=None the CAM is computed for the highest-scoring category.
        targets = None

        cam_algorithm = methods[method]
        with cam_algorithm(model=model,
                           target_layers=target_layers,
                           use_cuda=args.use_cuda) as cam:

            # AblationCAM and ScoreCAM have batched implementations.
            # You can override the internal batch size for faster computation.
            cam.batch_size = 32
            # kwargs are optional; fall back to the command-line defaults instead of
            # raising a KeyError when they are not supplied.
            aug_smooth = kwargs.get('aug_smooth', args.aug_smooth)
            eigen_smooth = kwargs.get('eigen_smooth', args.eigen_smooth)
            grayscale_cam = cam(input_tensor=input_tensor,
                                targets=targets,
                                aug_smooth=aug_smooth,
                                eigen_smooth=eigen_smooth)

            # Here grayscale_cam has only one image in the batch
            grayscale_cam = grayscale_cam[0, :]

            cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

            # cam_image is RGB encoded whereas "cv2.imwrite" requires BGR encoding.
            cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)

        gb_model = GuidedBackpropReLUModel(model=model, use_cuda=args.use_cuda)
        gb = gb_model(input_tensor, target_category=None)

        cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam])
        cam_gb = deprocess_image(cam_mask * gb)
        gb = deprocess_image(gb)

        # Re-run the classifier to report the predicted label and its probability.
        model.eval()
        image = Image.open(image_path)
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
        image = transform(image)
        # Perform forward pass
        with torch.no_grad():
            output = model(image.unsqueeze(0))

        # Convert output to probabilities
        probabilities = torch.softmax(output[0], dim=0)
        # Get predicted label and probability
        predicted_class = torch.argmax(probabilities).item()
        probability = probabilities[predicted_class].item()

        # Get predicted label name
        label_name = labels[str(predicted_class)]
        # print(f"Predicted label: {label_name}")

        # cv2.imwrite cannot handle paths containing non-ASCII (e.g. Chinese) characters.
        cv2.imwrite(output_path + label_name + '-' + method + "_cam3.jpg", cam_image)
        cv2.imwrite(output_path + label_name + '-' + method + "_gb.jpg", gb)
        cv2.imwrite(output_path + label_name + '-' + method + "_cam_gb.jpg", cam_gb)

        return {'output_path': output_path, 'Predicted_class': label_name, 'Probability': probability}

    @staticmethod
    def get_model(info, device):
        # if isinstance(info, dict):
        #     info = argparse.Namespace(**info)
        #
        # if isinstance(info.path, dict):
        #     if isinstance(info.path, str):
        #         info.path = json.loads(info.path)

        # load pytorch model from user upload files
        # if info.ownership != 'aitest' and ('upload' in info.type and 'pytorch' in info.type):
        #     return load_pytorch_model(state_dict_path=info.path['parameter_file'], device=device, net_file_path=info.path['structure_file'])

        # load a built-in torchvision image classification model
        if 'torchvision' == info.get('source'):
            model_class = getattr(models, info.get('model_name'))
            return model_class(pretrained=True)
        # Note: any other source falls through and implicitly returns None.


def load_pytorch_model(state_dict_path, model_class_name, device, net_file_path=None):
    """
    model = load_user_model('word_cnn_for_classification',
                            'C:\\Users\\pcl\\.cache\\textattack\\models\\classification\\cnn\\rotten-tomatoes\\pytorch_model.bin',
                            'WordCNNForClassification')
    """
    if net_file_path:
        # Turn the path of the network definition file into a dotted module path for import_module.
        net_file_path = str(net_file_path).replace('\\', '/')
        net_file_path = net_file_path.replace('./', '').replace('/', '.')
        model_class = getattr(import_module(net_file_path), model_class_name)
        model = model_class()
        model.load(state_dict_path, device)
        # tokenizer = model.tokenizer
        # model = textattack.models.wrappers.PyTorchModelWrapper(
        #     model, tokenizer
        # )
        return model
    else:
        return torch.load(state_dict_path, device)
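

# Minimal usage sketch. The image path and output directory below are placeholder
# assumptions for illustration, and 'aug_smooth'/'eigen_smooth' are the optional kwargs
# documented in ImageInterpretability.perform().
if __name__ == '__main__':
    interpreter = ImageInterpretability()
    result = interpreter.perform(
        image_path='cat_dog.jpg',             # hypothetical input image
        method='gradcam',                     # any key of the `methods` dict in perform()
        model_info={'source': 'torchvision',  # shape expected by get_model()
                    'model_name': 'resnet50'},
        output_path='./outputs/',             # prefix prepended to the generated CAM file names
        aug_smooth=False,
        eigen_smooth=False,
    )
    print(result)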