image_interprebility/pytorch_grad_cam/utils/model_targets.py

104 lines
3.0 KiB
Python
Raw Normal View History

2023-06-05 15:11:03 +08:00
import numpy as np
import torch
import torchvision
class ClassifierOutputTarget:
    """Pick the raw score of a single category out of a classifier output.

    Args:
        category: Index of the category whose score is returned.
    """

    def __init__(self, category):
        self.category = category

    def __call__(self, model_output):
        """Return the score(s) of ``self.category``.

        A 1-D output is indexed directly; any other output is indexed along
        its second axis, keeping the first axis intact.
        """
        is_single_vector = len(model_output.shape) == 1
        if not is_single_vector:
            return model_output[:, self.category]
        return model_output[self.category]
class ClassifierOutputSoftmaxTarget:
    """Pick the softmax probability of a single category.

    Args:
        category: Index of the category whose probability is returned.
    """

    def __init__(self, category):
        self.category = category

    def __call__(self, model_output):
        """Return the softmax (over the last axis) value of ``self.category``.

        A 1-D output is indexed directly; any other output is indexed along
        its second axis, keeping the first axis intact.
        """
        probabilities = torch.softmax(model_output, dim=-1)
        if len(model_output.shape) == 1:
            return probabilities[self.category]
        return probabilities[:, self.category]
class BinaryClassifierOutputTarget:
    """Sign-adjust a binary classifier output for the requested category.

    Args:
        category: When equal to 1 the output is returned unchanged (times 1);
            for any other value the output is multiplied by -1.
    """

    def __init__(self, category):
        self.category = category

    def __call__(self, model_output):
        """Return ``model_output`` multiplied by +1 or -1."""
        sign = 1 if self.category == 1 else -1
        return model_output * sign
class SoftmaxOutputTarget:
    """Turn the whole model output into softmax probabilities."""

    def __init__(self):
        """Nothing to configure."""

    def __call__(self, model_output):
        """Return the softmax of ``model_output`` taken over its last axis."""
        probabilities = torch.softmax(model_output, dim=-1)
        return probabilities
class RawScoresOutputTarget:
    """Pass the model output through untouched."""

    def __init__(self):
        """Nothing to configure."""

    def __call__(self, model_output):
        """Return ``model_output`` itself, without copying or transforming it."""
        return model_output
class SemanticSegmentationTarget:
    """Sum one category's scores over the pixels selected by a spatial mask.

    Args:
        category: Index of the category channel read from the model output.
        mask: Spatial mask multiplied element-wise with that category's score
            map (described upstream as a binary mask). May be a NumPy array
            or anything else ``torch.as_tensor`` accepts.
    """

    def __init__(self, category, mask):
        self.category = category
        # ``as_tensor`` shares memory with NumPy input (like ``from_numpy``)
        # and additionally accepts masks that are already tensors.
        self.mask = torch.as_tensor(mask)

    def __call__(self, model_output):
        """Return the masked sum of ``model_output[self.category]``.

        ``model_output`` is indexed as ``[category, :, :]``, i.e. it must be
        a 3-D tensor whose first axis enumerates categories.
        """
        # Follow the output's device instead of assuming CUDA whenever it is
        # available: a CPU model on a CUDA-capable host would otherwise fail
        # with a device mismatch. The moved mask is cached for later calls.
        if self.mask.device != model_output.device:
            self.mask = self.mask.to(model_output.device)
        return (model_output[self.category, :, :] * self.mask).sum()
class FasterRCNNBoxScoreTarget:
    """Score how well current detections match a set of reference detections.

    For every reference bounding box, the current detection with the highest
    IoU against it is looked up. If that IoU exceeds ``iou_threshold`` and the
    detection carries the same label, the box contributes
    ``IoU + classification score``; otherwise it contributes 0.
    The total score is the sum of all per-box contributions.

    Args:
        labels: Reference label for each reference bounding box.
        bounding_boxes: Reference boxes, iterated in lockstep with ``labels``.
            Each box must be convertible by ``torch.as_tensor`` to a 1-D
            tensor of coordinates in the format ``torchvision.ops.box_iou``
            expects.
        iou_threshold: Minimum IoU (exclusive) for a detection to count as a
            match. Defaults to 0.5.
    """

    def __init__(self, labels, bounding_boxes, iou_threshold=0.5):
        self.labels = labels
        self.bounding_boxes = bounding_boxes
        self.iou_threshold = iou_threshold

    def __call__(self, model_outputs):
        """Return the summed match score as a 1-element tensor.

        ``model_outputs`` is read through its ``"boxes"``, ``"labels"`` and
        ``"scores"`` entries; ``"boxes"`` must be a tensor.
        """
        detected_boxes = model_outputs["boxes"]
        # Build the accumulator on the detections' device rather than on CUDA
        # whenever it is available, so CPU inference on a CUDA-capable host
        # does not fail with a device mismatch.
        device = detected_boxes.device
        output = torch.zeros(1, device=device)

        if len(detected_boxes) == 0:
            return output

        for box, label in zip(self.bounding_boxes, self.labels):
            # Shape (1, K) so box_iou yields a (1, num_detections) matrix.
            box = torch.as_tensor(box, dtype=output.dtype, device=device)[None, :]
            ious = torchvision.ops.box_iou(box, detected_boxes)
            index = ious.argmax()
            if ious[0, index] > self.iou_threshold and model_outputs["labels"][index] == label:
                score = ious[0, index] + model_outputs["scores"][index]
                output = output + score
        return output