image_interprebility/pytorch_grad_cam/metrics/perturbation_confidence.py

110 lines
3.4 KiB
Python
Raw Normal View History

2023-06-05 15:11:03 +08:00
import torch
import numpy as np
from typing import List, Callable
import numpy as np
import cv2
class PerturbationConfidenceMetric:
    """Score a model on inputs that were perturbed according to their CAMs.

    Every sample of the batch is passed, together with its CAM, through
    ``perturbation``. The model is then run on the perturbed batch, and each
    target callable is applied to the matching model output.

    Args:
        perturbation: Callable invoked as ``perturbation(sample, cam)``, where
            ``sample`` is one CPU slice of the input batch and ``cam`` is the
            matching CAM converted with ``torch.from_numpy``. It returns a
            tensor (presumably shaped like ``sample`` -- TODO confirm) that
            is moved back to the input's device and stacked into a batch.
    """

    def __init__(self, perturbation):
        self.perturbation = perturbation

    def __call__(self, input_tensor: torch.Tensor,
                 cams: np.ndarray,
                 targets: List[Callable],
                 model: torch.nn.Module,
                 return_visualization=False,
                 return_diff=True):
        """Evaluate ``targets`` on the perturbed version of ``input_tensor``.

        Args:
            input_tensor: Batched model input; dimension 0 indexes samples.
            cams: Per-sample CAMs, indexed by batch position (``cams[i]``).
            targets: Callables zipped with the model outputs; each maps one
                output to a tensor score.
            model: Model whose forward passes run under ``torch.no_grad``.
            return_visualization: When true, also return the perturbed batch.
            return_diff: When true, subtract the scores obtained on the
                unperturbed input from the perturbed scores.

        Returns:
            A float32 NumPy array of scores (or score differences), or the
            tuple ``(scores, perturbed_batch)`` if ``return_visualization``.
        """
        baseline_scores = None
        if return_diff:
            with torch.no_grad():
                clean_outputs = model(input_tensor)
                baseline_scores = np.float32([
                    target(output).cpu().numpy()
                    for target, output in zip(targets, clean_outputs)
                ])

        # The perturbation runs on CPU copies; its results are moved back to
        # the input's device and stacked along a new leading batch axis.
        device = input_tensor.device
        perturbed_batch = torch.cat([
            self.perturbation(input_tensor[index, ...].cpu(),
                              torch.from_numpy(cams[index]))
            .to(device)
            .unsqueeze(0)
            for index in range(input_tensor.size(0))
        ])

        with torch.no_grad():
            perturbed_outputs = model(perturbed_batch)
        perturbed_scores = np.float32([
            target(output).cpu().numpy()
            for target, output in zip(targets, perturbed_outputs)
        ])

        result = (perturbed_scores - baseline_scores if return_diff
                  else perturbed_scores)
        if return_visualization:
            return result, perturbed_batch
        return result
class RemoveMostRelevantFirst:
    """Binarize a CAM and pass the binary mask to an imputer.

    In the binary mask, pixels whose CAM value lies below the threshold are
    marked 1 and the most relevant pixels are marked 0. The imputer is then
    called as ``imputer(input_tensor, binary_mask)``.

    Args:
        percentile: Number in [0, 100]. The threshold is that percentile of
            the mask values. Alternatively, the string ``'auto'`` selects the
            threshold with Otsu's method.
        imputer: Callable receiving ``(input_tensor, binary_mask)``.
    """

    def __init__(self, percentile, imputer):
        self.percentile = percentile
        self.imputer = imputer

    def __call__(self, input_tensor, mask):
        """Threshold ``mask`` and return ``imputer(input_tensor, binary)``.

        ``binary`` is a float32 tensor of 0/1 values on ``mask.device``.
        """
        # Threshold a detached host-side copy so this also works for masks
        # that live on an accelerator or that require grad.
        mask_np = mask.detach().cpu().numpy()
        if self.percentile != 'auto':
            threshold = np.percentile(mask_np, self.percentile)
            binary_mask = np.float32(mask_np < threshold)
        else:
            # Otsu needs an 8-bit image. Clip first so out-of-range values
            # cannot wrap around during the uint8 conversion.
            mask_u8 = np.uint8(np.clip(mask_np, 0, 1) * 255)
            # THRESH_BINARY_INV with maxval=1 gives 1 at or below the Otsu
            # threshold and 0 above it. This matches the percentile branch:
            # 1 = less relevant, 0 = most relevant, same 0/1 float32 scale.
            _, binary_mask = cv2.threshold(
                mask_u8, 0, 1, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
            binary_mask = np.float32(binary_mask)
        binary_mask = torch.from_numpy(binary_mask).to(mask.device)
        return self.imputer(input_tensor, binary_mask)
class RemoveLeastRelevantFirst(RemoveMostRelevantFirst):
    """Variant of ``RemoveMostRelevantFirst`` that inverts the CAM first.

    The mask is replaced by ``1 - mask`` before the parent's thresholding
    runs. This reverses the ordering of CAM values, so low-relevance pixels
    are treated the way high-relevance ones are in the parent class.
    """

    def __init__(self, percentile, imputer):
        super().__init__(percentile, imputer)

    def __call__(self, input_tensor, mask):
        inverted_mask = 1 - mask
        return super().__call__(input_tensor, inverted_mask)
class AveragerAcrossThresholds:
    """Average a perturbation metric over several percentile thresholds.

    Args:
        imputer: Factory called once per percentile as ``imputer(percentile)``.
            The object it returns is called as
            ``metric(input_tensor, cams, targets, model)``.
        percentiles: Iterable of percentiles to evaluate. Defaults to
            10, 20, ..., 90.

    Raises:
        ValueError: If ``percentiles`` is empty; the mean would otherwise be
            NaN.
    """

    def __init__(self, imputer, percentiles=None):
        self.imputer = imputer
        # Build a fresh list per instance instead of sharing a mutable default
        # argument. Copying also accepts any iterable (tuple, range, ...).
        if percentiles is None:
            percentiles = range(10, 100, 10)
        self.percentiles = list(percentiles)
        if not self.percentiles:
            raise ValueError("percentiles must contain at least one value")

    def __call__(self,
                 input_tensor: torch.Tensor,
                 cams: np.ndarray,
                 targets: List[Callable],
                 model: torch.nn.Module):
        """Return the element-wise mean of the metric over all percentiles."""
        scores = [
            self.imputer(percentile)(input_tensor, cams, targets, model)
            for percentile in self.percentiles
        ]
        return np.mean(np.float32(scores), axis=0)