110 lines
3.4 KiB
Python
110 lines
3.4 KiB
Python
|
import torch
|
||
|
import numpy as np
|
||
|
from typing import List, Callable
|
||
|
|
||
|
import numpy as np
|
||
|
import cv2
|
||
|
|
||
|
|
||
|
class PerturbationConfidenceMetric:
    """Score a model before and after perturbing its inputs with CAMs.

    The ``perturbation`` callable is invoked once per sample as
    ``perturbation(sample_on_cpu, cam_as_tensor)`` and must return a tensor;
    the per-sample results are stacked back into a batch and fed to the
    model again.
    """

    def __init__(self, perturbation):
        self.perturbation = perturbation

    def __call__(self, input_tensor: torch.Tensor,
                 cams: np.ndarray,
                 targets: List[Callable],
                 model: torch.nn.Module,
                 return_visualization=False,
                 return_diff=True):
        """Evaluate the targets on the perturbed batch.

        Args:
            input_tensor: Batch of inputs; the first dimension is the batch.
            cams: Per-sample maps, indexed along the first axis and converted
                with ``torch.from_numpy`` before being handed to the
                perturbation.
            targets: One callable per sample, applied to that sample's model
                output. Its result must support ``.cpu().numpy()``.
            model: Callable model; its output is iterated per sample.
            return_visualization: When true, also return the perturbed batch.
            return_diff: When true, return ``after - before`` scores;
                otherwise return only the scores after perturbation.

        Returns:
            A float32 NumPy array of scores, or a ``(scores, perturbed_batch)``
            tuple when ``return_visualization`` is true.
        """
        # The baseline forward pass is only needed when a difference is asked.
        baseline = None
        if return_diff:
            baseline = self._score(model, input_tensor, targets)

        # The perturbation runs on CPU copies; results go back to the
        # device of the original batch before being stacked.
        perturbed_batch = torch.cat([
            self.perturbation(input_tensor[idx, ...].cpu(),
                              torch.from_numpy(cams[idx]))
            .to(input_tensor.device)
            .unsqueeze(0)
            for idx in range(input_tensor.size(0))
        ])

        perturbed_scores = self._score(model, perturbed_batch, targets)
        result = perturbed_scores - baseline if return_diff \
            else perturbed_scores

        if return_visualization:
            return result, perturbed_batch
        return result

    @staticmethod
    def _score(model, batch, targets):
        """Run ``model`` without gradients and apply each target to its output."""
        with torch.no_grad():
            outputs = model(batch)
            per_sample = [target(output).cpu().numpy()
                          for target, output in zip(targets, outputs)]
        return np.float32(per_sample)
|
||
|
|
||
|
|
||
|
class RemoveMostRelevantFirst:
    """Perturbation that blanks out the highest-valued region of a mask.

    The relevance ``mask`` is binarised into a float32 keep-mask (1 for
    pixels below the threshold, 0 for the pixels at or above it) and the
    pair ``(input_tensor, keep_mask)`` is handed to ``imputer``.

    Args:
        percentile: Either a number passed to ``np.percentile`` to pick the
            threshold, or the string ``'auto'`` to let Otsu's method choose
            it.
        imputer: Callable ``imputer(input_tensor, binary_mask)`` whose return
            value is returned unchanged.
    """

    def __init__(self, percentile, imputer):
        self.percentile = percentile
        self.imputer = imputer

    def __call__(self, input_tensor, mask):
        """Binarise ``mask`` and apply the imputer.

        Args:
            input_tensor: Tensor forwarded untouched to the imputer.
            mask: Relevance map as a torch tensor. For ``'auto'`` mode the
                values are scaled by 255 and cast to uint8, so they are
                presumably expected to lie in [0, 1] -- confirm with callers.

        Returns:
            Whatever ``self.imputer`` returns.
        """
        # Threshold on a host-side NumPy copy so masks that live on another
        # device (or track gradients) work; the result is moved back below.
        mask_np = mask.detach().cpu().numpy()

        if self.percentile != 'auto':
            threshold = np.percentile(mask_np, self.percentile)
            binary_mask = np.float32(mask_np < threshold)
        else:
            # THRESH_BINARY_INV with maxval=1 yields 1 for pixels at or below
            # the Otsu threshold and 0 above it, i.e. the same polarity and
            # 0/1 scale as the percentile branch. (Plain THRESH_BINARY with
            # maxval=255 would keep the *most* relevant pixels instead.)
            _, binary_mask = cv2.threshold(
                np.uint8(mask_np * 255), 0, 1,
                cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
            binary_mask = np.float32(binary_mask)

        binary_mask = torch.from_numpy(binary_mask).to(mask.device)
        return self.imputer(input_tensor, binary_mask)
|
||
|
|
||
|
|
||
|
class RemoveLeastRelevantFirst(RemoveMostRelevantFirst):
    """Variant of ``RemoveMostRelevantFirst`` that works on the flipped mask.

    The mask is replaced by ``1 - mask`` before delegating to the parent
    class, so the pixels with the lowest values in the original mask are the
    ones the parent treats as highest-valued.
    """

    def __init__(self, percentile, imputer):
        super().__init__(percentile, imputer)

    def __call__(self, input_tensor, mask):
        """Invert ``mask`` and run the parent perturbation on it."""
        inverted_mask = 1 - mask
        return super().__call__(input_tensor, inverted_mask)
|
||
|
|
||
|
|
||
|
class AveragerAcrossThresholds:
    """Average a percentile-parameterised metric over several percentiles.

    Args:
        imputer: Factory called as ``imputer(percentile)``; it must return a
            callable accepting ``(input_tensor, cams, targets, model)``.
        percentiles: Percentiles to evaluate. When omitted, 10 through 90 in
            steps of 10 are used.
    """

    def __init__(
            self,
            imputer,
            percentiles=None):
        self.imputer = imputer
        # Build the default list per instance: a list literal in the
        # signature would be shared (and mutable) across all instances.
        if percentiles is None:
            percentiles = [10, 20, 30, 40, 50, 60, 70, 80, 90]
        self.percentiles = percentiles

    def __call__(self,
                 input_tensor: torch.Tensor,
                 cams: np.ndarray,
                 targets: List[Callable],
                 model: torch.nn.Module):
        """Evaluate the metric at every percentile and return the mean.

        Returns:
            The element-wise mean (over percentiles) of the per-percentile
            results, computed on a float32 array.
        """
        scores = [
            self.imputer(percentile)(input_tensor, cams, targets, model)
            for percentile in self.percentiles
        ]
        return np.mean(np.float32(scores), axis=0)
|