image_interprebility/pytorch_grad_cam/ablation_cam.py

149 lines
6.6 KiB
Python
Raw Permalink Normal View History

2023-06-05 15:11:03 +08:00
import numpy as np
import torch
import tqdm
from typing import Callable, List
from pytorch_grad_cam.base_cam import BaseCAM
from pytorch_grad_cam.utils.find_layers import replace_layer_recursive
from pytorch_grad_cam.ablation_layer import AblationLayer
""" Implementation of AblationCAM
https://openaccess.thecvf.com/content_WACV_2020/papers/Desai_Ablation-CAM_Visual_Explanations_for_Deep_Convolutional_Network_via_Gradient-free_Localization_WACV_2020_paper.pdf
Ablate individual activations, and then measure the drop in the target score.
In the current implementation, the target layer activations are cached, so they won't be re-computed.
However, layers before it, if any, will not be cached.
This means that if the target layer is a large block, for example model.features (in VGG), there will
be a large saving in run time.
Since we have to go over many channels and ablate them, and every channel ablation requires a forward pass,
it would be nice if we could avoid doing that for channels that won't contribute anyway, making it much faster.
The parameter ratio_channels_to_ablate controls how many channels should be ablated, using an experimental method
(to be improved). The default 1.0 value means that all channels will be ablated.
"""
class AblationCAM(BaseCAM):
    """Gradient-free CAM that weights each activation channel by the drop in
    the target score observed when that channel is ablated.

    Args:
        model: The network to explain.
        target_layers: Layers whose activations are ablated.
        use_cuda: Forwarded to ``BaseCAM``.
        reshape_transform: Forwarded to ``BaseCAM``.
        ablation_layer: Module temporarily swapped in for the target layer
            while channels are ablated. When omitted (or None), a fresh
            ``AblationLayer`` is created for this instance.
        batch_size: Number of ablated copies of one image run per forward
            pass.
        ratio_channels_to_ablate: Passed to
            ``ablation_layer.activations_to_be_ablated``; the default 1.0
            ablates all channels.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 target_layers: List[torch.nn.Module],
                 use_cuda: bool = False,
                 reshape_transform: Callable = None,
                 ablation_layer: torch.nn.Module = None,
                 batch_size: int = 32,
                 ratio_channels_to_ablate: float = 1.0) -> None:
        super().__init__(model,
                         target_layers,
                         use_cuda,
                         reshape_transform,
                         uses_gradients=False)
        self.batch_size = batch_size
        # The ablation layer carries per-call state (get_cam_weights feeds it
        # activations and rewrites its ``indices``) and is inserted into the
        # model, so a default instance must not be shared between AblationCAM
        # objects. Build it here rather than in the signature, where it would
        # be created once at class-definition time.
        if ablation_layer is None:
            ablation_layer = AblationLayer()
        self.ablation_layer = ablation_layer
        self.ratio_channels_to_ablate = ratio_channels_to_ablate

    def save_activation(self, module, input, output) -> None:
        """ Forward hook: store the raw output of the target layer. """
        self.activations = output

    def assemble_ablation_scores(self,
                                 new_scores: list,
                                 original_score: float,
                                 ablated_channels: np.ndarray,
                                 number_of_channels: int) -> List[float]:
        """Expand the scores of the ablated channels to a per-channel list.

        Args:
            new_scores: Target score measured after ablating
                ``ablated_channels[k]``, in the same order as
                ``ablated_channels``.
            original_score: Unablated target score; used for every channel
                that was skipped (not ablated).
            ablated_channels: Channel indices that were ablated.
            number_of_channels: Total number of channels in the activation.

        Returns:
            A list of ``number_of_channels`` scores, indexed by channel.
        """
        index = 0
        result = []
        # Sort channels (and their scores in lockstep) so a single forward
        # sweep over 0..number_of_channels-1 can merge them in.
        sorted_indices = np.argsort(ablated_channels)
        ablated_channels = ablated_channels[sorted_indices]
        new_scores = np.float32(new_scores)[sorted_indices]
        for i in range(number_of_channels):
            if index < len(ablated_channels) and ablated_channels[index] == i:
                weight = new_scores[index]
                index = index + 1
            else:
                weight = original_score
            result.append(weight)
        return result

    def get_cam_weights(self,
                        input_tensor: torch.Tensor,
                        target_layer: torch.nn.Module,
                        targets: List[Callable],
                        activations: torch.Tensor,
                        grads: torch.Tensor) -> np.ndarray:
        """Compute one weight per (image, channel) from ablation score drops.

        For image ``b`` and channel ``c`` the weight is
        ``(original_score[b] - ablated_score[b, c]) / original_score[b]``.
        Channels that were not ablated keep the original score, so their
        weight is 0. ``grads`` is not used by this method.

        The target layer is replaced by ``self.ablation_layer`` during the
        computation and is always put back, even if an exception is raised.
        Note: an original score of exactly 0 makes the final division produce
        inf/nan weights for that image.
        """
        # Forward pass on the unmodified model: compute the target scores and
        # let the hook cache the target layer's output in self.activations.
        handle = target_layer.register_forward_hook(self.save_activation)
        try:
            with torch.no_grad():
                outputs = self.model(input_tensor)
        finally:
            # Detach the hook even if the forward pass fails.
            handle.remove()
        original_scores = np.float32(
            [target(output).cpu().item()
             for target, output in zip(targets, outputs)])

        ablation_layer = self.ablation_layer
        number_of_channels = activations.shape[1]
        weights = []

        # Swap the ablation layer in for the target layer. The finally clause
        # guarantees the caller's model is restored even on error.
        replace_layer_recursive(self.model, target_layer, ablation_layer)
        try:
            # This is a "gradient free" method, so no autograd is needed.
            with torch.no_grad():
                # Loop over each image of the batch and ablate its channels.
                for batch_index, (target, tensor) in enumerate(
                        zip(targets, input_tensor)):
                    new_scores = []
                    # Copies of this image; one copy per channel ablated in a
                    # single forward pass.
                    batch_tensor = tensor.repeat(self.batch_size, 1, 1, 1)
                    # Normally every channel is ablated; a lower
                    # ratio_channels_to_ablate lets the ablation layer select
                    # fewer channels to save forward passes.
                    channels_to_ablate = \
                        ablation_layer.activations_to_be_ablated(
                            activations[batch_index, :],
                            self.ratio_channels_to_ablate)
                    number_channels_to_ablate = len(channels_to_ablate)

                    for i in tqdm.tqdm(range(0,
                                             number_channels_to_ablate,
                                             self.batch_size)):
                        # The final chunk may be smaller than batch_size.
                        if i + self.batch_size > number_channels_to_ablate:
                            batch_tensor = batch_tensor[:(
                                number_channels_to_ablate - i)]

                        # Change the state of the ablation layer so it
                        # ablates the next channels.
                        # TBD: Move this into the ablation layer forward pass.
                        ablation_layer.set_next_batch(
                            input_batch_index=batch_index,
                            activations=self.activations,
                            num_channels_to_ablate=batch_tensor.size(0))
                        score = [target(o).cpu().item()
                                 for o in self.model(batch_tensor)]
                        new_scores.extend(score)
                        # Advance past the ``indices`` entries used by this
                        # pass (presumably the channels just ablated).
                        ablation_layer.indices = \
                            ablation_layer.indices[batch_tensor.size(0):]

                    weights.extend(self.assemble_ablation_scores(
                        new_scores,
                        original_scores[batch_index],
                        channels_to_ablate,
                        number_of_channels))
        finally:
            # Replace the model back to the original state.
            replace_layer_recursive(self.model, ablation_layer, target_layer)

        weights = np.float32(weights).reshape(activations.shape[:2])
        original_scores = original_scores[:, None]
        # Relative drop of the target score caused by ablating each channel.
        return (original_scores - weights) / original_scores