image_interprebility/task/image_interpretability.py
2023-06-05 15:11:03 +08:00

239 lines
9.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import cv2
import numpy as np
import torch
import json
from torchvision import models
from importlib import import_module
from utils import LogHelper
from pytorch_grad_cam import GradCAM, \
HiResCAM, \
ScoreCAM, \
GradCAMPlusPlus, \
AblationCAM, \
XGradCAM, \
EigenCAM, \
EigenGradCAM, \
LayerCAM, \
FullGrad, \
GradCAMElementWise
from pytorch_grad_cam.utils.find_layers import find_layer_types_recursive
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, \
deprocess_image, \
preprocess_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from PIL import Image
from torchvision import transforms
import urllib.request
import json
# Label table used by ImageInterpretability.perform, which indexes it with
# str(predicted class index) to name the top-1 prediction.
# NOTE(review): the path is resolved against the current working directory,
# not this module's directory -- confirm that is intended.
# JSON text is UTF-8 by definition; pass the encoding explicitly so loading
# does not depend on the platform locale (e.g. GBK / cp1252 on Windows).
with open("imagenet_1000.json", encoding="utf-8") as f:
    labels = json.load(f)
# Device name handed to ImageInterpretability.get_model. Only advertise CUDA
# when a CUDA device actually exists.
available_device = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_args(argv=None):
    """Parse the command-line options of the interpretability tool.

    :param argv: list of arguments to parse; ``None`` (the default) parses
        ``sys.argv[1:]`` exactly as before.
    :return: ``argparse.Namespace`` with ``use_cuda``, ``aug_smooth`` and
        ``eigen_smooth``. ``use_cuda`` is True only when GPU use is requested
        (the default) *and* ``torch.cuda.is_available()`` reports a device.
    """
    parser = argparse.ArgumentParser()
    # GPU use is on by default. --use-cuda is kept for backwards compatibility;
    # since its default is already True it cannot switch anything off, so
    # --no-cuda (same dest) is provided to force CPU execution.
    parser.add_argument('--use-cuda', action='store_true', default=True,
                        help='Use NVIDIA GPU acceleration')
    parser.add_argument('--no-cuda', dest='use_cuda', action='store_false',
                        help='Disable GPU acceleration and compute on the CPU')
    parser.add_argument('--aug_smooth', action='store_true',
                        help='Apply test time augmentation to smooth the CAM')
    parser.add_argument(
        '--eigen_smooth',
        action='store_true',
        # The two literals are concatenated; the trailing space keeps the
        # words "component" and "of" apart.
        help='Reduce noise by taking the first principal component '
             'of cam_weights*activations')
    args = parser.parse_args(argv)
    args.use_cuda = args.use_cuda and torch.cuda.is_available()
    if args.use_cuda:
        print('Using GPU for acceleration')
    else:
        print('Using CPU for computation')
    return args
class ImageInterpretability:
    """Produce CAM / guided-backprop visual explanations for an image classifier.

    ``perform`` runs one pytorch-grad-cam algorithm on an image, combines it
    with guided backpropagation, writes three JPEG visualisations and returns
    the model's top-1 prediction for the image.
    """

    def __init__(self):
        super().__init__()

    def perform(self, image_path: str, method: str, model_info: dict, output_path: str, log_path=None, **kwargs):
        """Explain a model's prediction on one image.

        :param image_path: (type=str, required=True) path of the input image.
        :param method: (type=str, required=True) interpretability method, one of
            ['gradcam', 'hirescam', 'scorecam', 'gradcam++', 'ablationcam',
            'xgradcam', 'eigencam', 'eigengradcam', 'layercam',
            'gradcamelementwise', 'fullgrad'].
        :param model_info: (type=dict, required=True) model description read by
            ``get_model``: ``source`` (only 'torchvision' is supported) and
            ``model_name`` (a torchvision.models constructor such as resnet18,
            resnet50, vgg11, vgg13, vgg16, vgg19, densenet161, mnasnet1_0).
            The name also selects the CAM target layers.
        :param output_path: (type=str, required=True) prefix the output file names
            are appended to -- pass a directory *with* a trailing separator.
        :param log_path: currently unused.

        Optional keyword arguments:
            aug_smooth (bool, default False): apply test-time augmentation to
                improve the CAM quality.
            eigen_smooth (bool, default False): take the first principal component
                of cam_weights*activations to reduce noise in the result.
            use_cuda (bool, default True): use the GPU when one is available.

        :return: dict with ``output_path``, ``Predicted_class`` (label name) and
            ``Probability`` (softmax probability of that class).
        :raises KeyError: if ``method`` is not a supported method name.
        :raises ValueError: if ``model_info`` has an unsupported source, or the
            image cannot be decoded.
        :raises OSError: if the image cannot be read or an output cannot be written.
        """
        methods = {
            "gradcam": GradCAM,
            "hirescam": HiResCAM,
            "scorecam": ScoreCAM,
            "gradcam++": GradCAMPlusPlus,
            "ablationcam": AblationCAM,
            "xgradcam": XGradCAM,
            "eigencam": EigenCAM,
            "eigengradcam": EigenGradCAM,
            "layercam": LayerCAM,
            "fullgrad": FullGrad,
            "gradcamelementwise": GradCAMElementWise,
        }
        # Validate before the (possibly downloaded) model is loaded.
        if method not in methods:
            raise KeyError(f"Unknown interpretability method {method!r}; "
                           f"expected one of {sorted(methods)}")
        # These switches are optional; looking them up with [] raised KeyError.
        aug_smooth = kwargs.get('aug_smooth', False)
        eigen_smooth = kwargs.get('eigen_smooth', False)
        # Decided here instead of via get_args(): parsing sys.argv inside a
        # library call aborts the host process when it has its own CLI options.
        use_cuda = kwargs.get('use_cuda', True) and torch.cuda.is_available()
        print('Using GPU for acceleration' if use_cuda else 'Using CPU for computation')

        model = self.get_model(info=model_info, device=available_device)
        if model is None:
            raise ValueError(f"Cannot load a model from source {model_info.get('source')!r}; "
                             "only 'torchvision' models are supported")
        model.eval()
        target_layers = self._select_target_layers(model, model_info.get('model_name'))

        rgb_img = self._read_rgb_image(image_path)
        input_tensor = preprocess_image(rgb_img,
                                        mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])

        with methods[method](model=model,
                             target_layers=target_layers,
                             use_cuda=use_cuda) as cam:
            # AblationCAM and ScoreCAM have batched implementations.
            # You can override the internal batch size for faster computation.
            cam.batch_size = 32
            grayscale_cam = cam(input_tensor=input_tensor,
                                targets=None,
                                aug_smooth=aug_smooth,
                                eigen_smooth=eigen_smooth)
        # The batch holds a single image.
        grayscale_cam = grayscale_cam[0, :]
        cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
        # cam_image is RGB encoded whereas the OpenCV encoder expects BGR.
        cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)

        gb_model = GuidedBackpropReLUModel(model=model, use_cuda=use_cuda)
        gb = gb_model(input_tensor, target_category=None)
        cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam])
        cam_gb = deprocess_image(cam_mask * gb)
        gb = deprocess_image(gb)

        label_name, probability = self._classify(model, image_path)

        prefix = output_path + label_name + '-' + method
        self._write_image(prefix + "_cam3.jpg", cam_image)
        self._write_image(prefix + "_gb.jpg", gb)
        self._write_image(prefix + "_cam_gb.jpg", cam_gb)
        return {'output_path': output_path, 'Predicted_class': label_name, 'Probability': probability}

    @staticmethod
    def _select_target_layers(model, model_name):
        """Return the list of layers whose activations the CAM is computed from."""
        name = model_name.lower()
        if 'resnet' in name:
            return [model.layer4]
        if 'vgg' in name or 'densenet' in name:
            return [model.features[-1]]
        if 'mnasnet' in name:
            return [model.layers[-1]]
        # Unknown architecture: fall back to every ReLU layer in the network.
        return find_layer_types_recursive(model, [torch.nn.ReLU])

    @staticmethod
    def _read_rgb_image(image_path):
        """Load ``image_path`` as an RGB float32 array scaled to [0, 1].

        The file is read with numpy and decoded with ``cv2.imdecode`` rather than
        ``cv2.imread``, so paths containing non-ASCII (e.g. Chinese) characters work.
        """
        data = np.fromfile(image_path, dtype=np.uint8)
        bgr = cv2.imdecode(data, cv2.IMREAD_COLOR) if data.size else None
        if bgr is None:
            raise ValueError(f"Could not decode image {image_path!r}")
        # OpenCV decodes to BGR; reverse the channel axis to get RGB.
        return np.float32(bgr[:, :, ::-1]) / 255

    @staticmethod
    def _classify(model, image_path):
        """Return ``(label name, softmax probability)`` of the model's top-1 class."""
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
        # Context manager closes the file; convert('RGB') gives grayscale, RGBA
        # and palette images the 3 channels that Normalize expects.
        with Image.open(image_path) as img:
            image = transform(img.convert('RGB'))
        # With use_cuda the pytorch-grad-cam wrappers move the model to the GPU
        # (Module.cuda() is in-place), so send the input to wherever the
        # weights live instead of assuming the CPU.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
        model.eval()
        with torch.no_grad():
            output = model(image.unsqueeze(0).to(device))
        probabilities = torch.softmax(output[0], dim=0)
        predicted_class = int(torch.argmax(probabilities).item())
        probability = probabilities[predicted_class].item()
        return labels[str(predicted_class)], probability

    @staticmethod
    def _write_image(path, image, ext='.jpg'):
        """Encode ``image`` (BGR / grayscale uint8) with OpenCV and write it to ``path``.

        ``cv2.imwrite`` cannot handle paths containing Chinese characters and only
        signals failure through its return value. Encoding in memory and writing
        with numpy supports such paths and raises on failure.
        """
        ok, buffer = cv2.imencode(ext, image)
        if not ok:
            raise OSError(f"Could not encode image for {path!r}")
        buffer.tofile(path)

    @staticmethod
    def get_model(info, device):
        """Instantiate the model described by ``info``.

        Only ``info['source'] == 'torchvision'`` is supported: the constructor
        ``torchvision.models.<info['model_name']>`` is called with
        ``pretrained=True``. Any other source returns ``None``.

        :param info: dict with ``source`` and ``model_name`` keys.
        :param device: accepted for interface compatibility; not used here.
        """
        # TODO: user-uploaded models (see load_pytorch_model) are not wired up yet.
        if 'torchvision' == info.get('source'):
            model_class = getattr(models, info.get('model_name'))
            return model_class(pretrained=True)
        return None
def load_pytorch_model(state_dict_path, model_class_name, device, net_file_path=None):
    """Load a user-supplied PyTorch model.

    With ``net_file_path``:
        Convert the path (``./pkg/net.py``, ``pkg\\net`` or ``pkg/net``) into a
        dotted module name (``pkg.net``) and import it. Then instantiate
        ``model_class_name`` from that module with no arguments, call the
        instance's ``load(state_dict_path, device)`` and return the instance.
        The module must be importable from ``sys.path``.

    Without ``net_file_path``:
        Return ``torch.load(state_dict_path, device)``, i.e. the pickled object
        loaded with ``map_location=device``.

    Example::

        model = load_pytorch_model(
            'C:\\Users\\pcl\\.cache\\textattack\\models\\classification\\cnn\\rotten-tomatoes\\pytorch_model.bin',
            'WordCNNForClassification', 'cpu',
            net_file_path='word_cnn_for_classification')

    SECURITY: ``torch.load`` unpickles arbitrary objects and ``import_module``
    executes the network file, so only call this on trusted uploads.
    """
    if net_file_path:
        module_name = str(net_file_path).replace('\\', '/')
        # Structure files are usually given as paths to the .py file; drop the
        # suffix so 'pkg/net.py' imports as 'pkg.net' rather than 'pkg.net.py'.
        if module_name.endswith('.py'):
            module_name = module_name[:-3]
        module_name = module_name.replace('./', '').replace('/', '.')
        model_class = getattr(import_module(module_name), model_class_name)
        model = model_class()
        model.load(state_dict_path, device)
        return model
    return torch.load(state_dict_path, device)