239 lines
9.0 KiB
Python
239 lines
9.0 KiB
Python
import argparse
|
||
import cv2
|
||
import numpy as np
|
||
import torch
|
||
import json
|
||
from torchvision import models
|
||
from importlib import import_module
|
||
from utils import LogHelper
|
||
from pytorch_grad_cam import GradCAM, \
|
||
HiResCAM, \
|
||
ScoreCAM, \
|
||
GradCAMPlusPlus, \
|
||
AblationCAM, \
|
||
XGradCAM, \
|
||
EigenCAM, \
|
||
EigenGradCAM, \
|
||
LayerCAM, \
|
||
FullGrad, \
|
||
GradCAMElementWise
|
||
from pytorch_grad_cam.utils.find_layers import find_layer_types_recursive
|
||
from pytorch_grad_cam import GuidedBackpropReLUModel
|
||
from pytorch_grad_cam.utils.image import show_cam_on_image, \
|
||
deprocess_image, \
|
||
preprocess_image
|
||
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
||
from PIL import Image
|
||
from torchvision import transforms
|
||
import urllib.request
|
||
import json
|
||
|
||
|
||
# Load the class-index -> label-name lookup used to name predictions.
# The path is relative, so it is resolved against the current working
# directory at import time.
# JSON is read as UTF-8 explicitly so the result does not depend on the
# platform's default text encoding.
with open("imagenet_1000.json", encoding="utf-8") as f:
    labels = json.load(f)


# Device name handed to ImageInterpretability.get_model().
available_device = 'cuda'
|
||
def get_args():
    """Parse the command-line options that control CAM generation.

    Recognised flags:
        --use-cuda      Request GPU acceleration (already the default).
        --cpu           Force CPU computation; clears ``use_cuda``.
        --aug_smooth    Apply test-time augmentation to smooth the CAM.
        --eigen_smooth  Keep only the first principal component of
                        cam_weights*activations to reduce noise.

    Returns:
        argparse.Namespace with boolean attributes ``use_cuda``,
        ``aug_smooth`` and ``eigen_smooth``. ``use_cuda`` is True only when
        it was requested AND ``torch.cuda.is_available()`` reports a GPU.

    Note:
        Arguments are read from ``sys.argv``; unknown arguments make
        argparse exit the process.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--use-cuda', action='store_true', default=True,
                        help='Use NVIDIA GPU acceleration')
    # '--use-cuda' is store_true with default=True, so on its own it can
    # never switch the GPU off. '--cpu' writes False to the same
    # destination to give callers a way to opt out.
    parser.add_argument('--cpu', dest='use_cuda', action='store_false',
                        help='Force CPU computation even if a GPU is available')
    parser.add_argument('--aug_smooth', action='store_true',
                        help='Apply test time augmentation to smooth the CAM')
    parser.add_argument(
        '--eigen_smooth',
        action='store_true',
        # Adjacent literals are concatenated, so the first one must end
        # with a space.
        help='Reduce noise by taking the first principle componenet '
             'of cam_weights*activations')
    args = parser.parse_args()

    # Only keep the GPU request if a CUDA device is actually usable.
    args.use_cuda = args.use_cuda and torch.cuda.is_available()
    if args.use_cuda:
        print('Using GPU for acceleration')
    else:
        print('Using CPU for computation')

    return args
|
||
|
||
|
||
class ImageInterpretability():
    """Produce CAM-based visual explanations for an image classifier.

    ``perform`` builds a class-activation map with one of the
    pytorch_grad_cam algorithms, combines it with guided backpropagation,
    classifies the image and writes three result images to disk.
    """

    def __init__(self):
        super().__init__()

    def perform(self, image_path: str, method: str, model_info: dict,
                output_path: str, log_path=None, **kwargs):
        """Explain a model's prediction for one image and save the results.

        Args:
            image_path: Path of the input image.
            method: Interpretability algorithm, one of 'gradcam',
                'hirescam', 'scorecam', 'gradcam++', 'ablationcam',
                'xgradcam', 'eigencam', 'eigengradcam', 'layercam',
                'fullgrad', 'gradcamelementwise'.
            model_info: Model description. The keys read here are
                'source' (only 'torchvision' is handled by ``get_model``)
                and 'model_name' (e.g. resnet18, resnet50, vgg11, vgg13,
                vgg16, vgg19, densenet161, mnasnet1_0).
            output_path: Prefix for the written files. The label name,
                method and a suffix are appended by plain string
                concatenation, so include a trailing separator if a
                directory is meant.
            log_path: Currently unused.
            **kwargs: Optional settings.
                aug_smooth (bool): use test-time augmentation to improve
                    the CAM quality.
                eigen_smooth (bool): take the first principal component
                    of cam_weights*activations to reduce noise.
                When omitted, the corresponding command-line flag value
                from ``get_args`` is used.

        Returns:
            dict with 'output_path', 'Predicted_class' (label name) and
            'Probability' (softmax probability of the predicted class).

        Raises:
            KeyError: if ``method`` is not a known algorithm name.
            ValueError: if no model could be built from ``model_info`` or
                the image could not be read.

        Note:
            ``get_args`` parses ``sys.argv`` on every call.
        """
        args = get_args()

        methods = {
            "gradcam": GradCAM,
            "hirescam": HiResCAM,
            "scorecam": ScoreCAM,
            "gradcam++": GradCAMPlusPlus,
            "ablationcam": AblationCAM,
            "xgradcam": XGradCAM,
            "eigencam": EigenCAM,
            "eigengradcam": EigenGradCAM,
            "layercam": LayerCAM,
            "fullgrad": FullGrad,
            "gradcamelementwise": GradCAMElementWise,
        }

        model = self.get_model(info=model_info, device=available_device)
        if model is None:
            # get_model returns None for unsupported sources; fail here with
            # a clear message instead of an AttributeError further down.
            raise ValueError(
                "Unsupported model source: %r" % (model_info.get('source'),))

        # Pick the layer whose activations are explained, per architecture.
        model_name = model_info.get('model_name').lower()
        if 'resnet' in model_name:
            target_layers = [model.layer4]
        elif 'vgg' in model_name:
            target_layers = [model.features[-1]]
        elif 'densenet' in model_name:
            target_layers = [model.features[-1]]
        elif 'mnasnet' in model_name:
            target_layers = [model.layers[-1]]
        else:
            target_layers = find_layer_types_recursive(model, [torch.nn.ReLU])

        bgr_img = cv2.imread(image_path, 1)
        if bgr_img is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise ValueError("Could not read image: %r" % (image_path,))
        # OpenCV loads BGR; reverse the channel axis to get RGB in [0, 1].
        rgb_img = bgr_img[:, :, ::-1]
        rgb_img = np.float32(rgb_img) / 255
        input_tensor = preprocess_image(rgb_img,
                                        mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])

        # No explicit target is given to the CAM call.
        targets = None

        # The smoothing options are optional; fall back to the CLI flags.
        aug_smooth = kwargs.get('aug_smooth', args.aug_smooth)
        eigen_smooth = kwargs.get('eigen_smooth', args.eigen_smooth)

        cam_algorithm = methods[method]
        with cam_algorithm(model=model,
                           target_layers=target_layers,
                           use_cuda=args.use_cuda) as cam:

            # AblationCAM and ScoreCAM have batched implementations.
            # You can override the internal batch size for faster computation.
            cam.batch_size = 32
            grayscale_cam = cam(input_tensor=input_tensor,
                                targets=targets,
                                aug_smooth=aug_smooth,
                                eigen_smooth=eigen_smooth)

            # Here grayscale_cam has only one image in the batch
            grayscale_cam = grayscale_cam[0, :]

            cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

            # cam_image is RGB encoded whereas "cv2.imwrite" requires BGR encoding.
            cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)

        gb_model = GuidedBackpropReLUModel(model=model, use_cuda=args.use_cuda)
        gb = gb_model(input_tensor, target_category=None)

        # Weight the guided-backprop image by the CAM (3-channel mask).
        cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam])
        cam_gb = deprocess_image(cam_mask * gb)
        gb = deprocess_image(gb)

        # Classify the image to obtain the label used in the file names.
        # NOTE(review): this uses a resized/center-cropped input, whereas the
        # CAM above was computed on the unresized image - confirm intended.
        model.eval()
        # Normalize below expects exactly three channels.
        image = Image.open(image_path).convert('RGB')
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
        image = transform(image)

        # The CAM helpers may have moved the model to the GPU; run the
        # forward pass on whatever device the weights live on.
        model_device = next(model.parameters()).device
        with torch.no_grad():
            output = model(image.unsqueeze(0).to(model_device))

        # Convert output to probabilities
        probabilities = torch.softmax(output[0], dim=0)
        # Get predicted label and probability
        predicted_class = torch.argmax(probabilities).item()
        probability = probabilities[predicted_class].item()

        # Get predicted label name
        label_name = labels[str(predicted_class)]

        # cv2.imwrite does not support paths containing Chinese characters.
        cv2.imwrite(output_path+label_name+'-'+method+"_cam3.jpg", cam_image)
        cv2.imwrite(output_path+label_name+'-'+method+"_gb.jpg", gb)
        cv2.imwrite(output_path+label_name+'-'+method+"_cam_gb.jpg", cam_gb)

        return {'output_path': output_path, 'Predicted_class': label_name, 'Probability': probability}

    @staticmethod
    def get_model(info, device):
        """Build the model described by ``info``.

        Args:
            info: Mapping with at least 'source' and 'model_name'.
            device: Accepted for interface compatibility; not used here.

        Returns:
            A pretrained torchvision model when ``info['source']`` is
            'torchvision', otherwise None.
        """
        # Loading user-uploaded PyTorch models (parameter file + structure
        # file) is not wired up here; only torchvision models are handled.

        # Load a built-in model from torchvision by attribute name.
        if 'torchvision' == info.get('source'):
            model_class = getattr(models, info.get('model_name'))
            return model_class(pretrained=True)

        return None
|
||
|
||
|
||
|
||
def load_pytorch_model(state_dict_path, model_class_name, device, net_file_path=None):
    """Load a PyTorch model from disk.

    Args:
        state_dict_path: Path of the saved weights / serialized model.
        model_class_name: Name of the model class inside the module given
            by ``net_file_path``. Only used when ``net_file_path`` is set.
        device: Passed to the model's ``load`` method, or used as
            ``map_location`` for ``torch.load``.
        net_file_path: Optional path of the Python file (or dotted module
            name) that defines the model class, relative to an importable
            root, e.g. ``./nets/word_cnn.py`` or ``nets\\word_cnn``.
            Backslashes, ``.`` path segments and a trailing ``.py`` are
            normalised away before import.

    Returns:
        With ``net_file_path``: an instance of the named class after its
        ``load(state_dict_path, device)`` method was called. Without it:
        whatever ``torch.load`` deserializes from ``state_dict_path``.

    Example:
        model = load_pytorch_model(
            'C:\\Users\\pcl\\.cache\\textattack\\models\\classification\\cnn\\rotten-tomatoes\\pytorch_model.bin',
            'WordCNNForClassification',
            'cpu',
            net_file_path='word_cnn_for_classification')
    """
    if net_file_path:
        module_path = str(net_file_path).replace('\\', '/')
        # A file path may carry the '.py' suffix; a module name must not.
        if module_path.endswith('.py'):
            module_path = module_path[:-3]
        # Drop empty and '.' segments ("./a/./b" -> "a.b") instead of
        # blindly deleting every "./" substring.
        segments = [seg for seg in module_path.split('/') if seg not in ('', '.')]
        module_name = '.'.join(segments)

        model_class = getattr(import_module(module_name), model_class_name)
        model = model_class()
        # NOTE(review): assumes the user-defined class exposes
        # load(path, device) - confirm against the uploaded model code.
        model.load(state_dict_path, device)
        return model

    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only call this on files from a trusted source.
    return torch.load(state_dict_path, device)
|
||
|
||
|
||
|
||
|
||
|