import numpy as np
import torch
import tqdm
from typing import Callable, List
from pytorch_grad_cam.base_cam import BaseCAM
from pytorch_grad_cam.utils.find_layers import replace_layer_recursive
from pytorch_grad_cam.ablation_layer import AblationLayer

"""
Implementation of AblationCAM
https://openaccess.thecvf.com/content_WACV_2020/papers/Desai_Ablation-CAM_Visual_Explanations_for_Deep_Convolutional_Network_via_Gradient-free_Localization_WACV_2020_paper.pdf

Ablate individual activation channels, and then measure the drop in the target score.

In the current implementation, the target layer activations are cached, so they won't be re-computed.
However, layers before it, if any, will not be cached.
This means that if the target layer is a large block, for example model.features (in VGG),
there will be a large saving in run time.

Since we have to go over many channels and ablate them, and every channel ablation requires a forward pass,
it would be nice if we could avoid doing that for channels that won't contribute anyway, making it much faster.
The parameter ratio_channels_to_ablate controls how many channels should be ablated, using an experimental method
(to be improved). The default value of 1.0 means that all channels will be ablated.
"""


class AblationCAM(BaseCAM):
    def __init__(self,
                 model: torch.nn.Module,
                 target_layers: List[torch.nn.Module],
                 use_cuda: bool = False,
                 reshape_transform: Callable = None,
                 ablation_layer: torch.nn.Module = AblationLayer(),
                 batch_size: int = 32,
                 ratio_channels_to_ablate: float = 1.0) -> None:
        super(AblationCAM, self).__init__(model,
                                          target_layers,
                                          use_cuda,
                                          reshape_transform,
                                          uses_gradients=False)
        self.batch_size = batch_size
        self.ablation_layer = ablation_layer
        self.ratio_channels_to_ablate = ratio_channels_to_ablate

    def save_activation(self, module, input, output) -> None:
        """ Helper function to save the raw activations from the target layer """
        self.activations = output

    def assemble_ablation_scores(self,
                                 new_scores: list,
                                 original_score: float,
                                 ablated_channels: np.ndarray,
                                 number_of_channels: int) -> np.ndarray:
        """ Take the value from the channels that were ablated,
            and just set the original score for the channels that were skipped """
        index = 0
        result = []
        sorted_indices = np.argsort(ablated_channels)
        ablated_channels = ablated_channels[sorted_indices]
        new_scores = np.float32(new_scores)[sorted_indices]

        for i in range(number_of_channels):
            if index < len(ablated_channels) and ablated_channels[index] == i:
                weight = new_scores[index]
                index = index + 1
            else:
                weight = original_score
            result.append(weight)

        return result

    def get_cam_weights(self,
                        input_tensor: torch.Tensor,
                        target_layer: torch.nn.Module,
                        targets: List[Callable],
                        activations: torch.Tensor,
                        grads: torch.Tensor) -> np.ndarray:
        # Do a forward pass, compute the target scores, and cache the
        # activations.
        handle = target_layer.register_forward_hook(self.save_activation)
        with torch.no_grad():
            outputs = self.model(input_tensor)
            handle.remove()
            original_scores = np.float32(
                [target(output).cpu().item()
                 for target, output in zip(targets, outputs)])

        # Replace the layer with the ablation layer.
        # When we finish, we will replace it back, so the original model is
        # unchanged.
        ablation_layer = self.ablation_layer
        replace_layer_recursive(self.model, target_layer, ablation_layer)

        number_of_channels = activations.shape[1]
        weights = []
        # This is a "gradient free" method, so we don't need gradients here.
        with torch.no_grad():
            # Loop over each of the batch images and ablate activations for it.
            for batch_index, (target, tensor) in enumerate(
                    zip(targets, input_tensor)):
                new_scores = []
                batch_tensor = tensor.repeat(self.batch_size, 1, 1, 1)

                # Check which channels should be ablated. Normally this will
                # be all channels, but we can also try to speed this up by
                # using a low ratio_channels_to_ablate.
                channels_to_ablate = ablation_layer.activations_to_be_ablated(
                    activations[batch_index, :],
                    self.ratio_channels_to_ablate)
                number_channels_to_ablate = len(channels_to_ablate)

                for i in tqdm.tqdm(
                        range(0, number_channels_to_ablate, self.batch_size)):
                    if i + self.batch_size > number_channels_to_ablate:
                        batch_tensor = batch_tensor[:(
                            number_channels_to_ablate - i)]

                    # Change the state of the ablation layer so it ablates the
                    # next channels.
                    # TBD: Move this into the ablation layer forward pass.
                    ablation_layer.set_next_batch(
                        input_batch_index=batch_index,
                        activations=self.activations,
                        num_channels_to_ablate=batch_tensor.size(0))
                    score = [target(o).cpu().item()
                             for o in self.model(batch_tensor)]
                    new_scores.extend(score)
                    ablation_layer.indices = ablation_layer.indices[
                        batch_tensor.size(0):]

                new_scores = self.assemble_ablation_scores(
                    new_scores,
                    original_scores[batch_index],
                    channels_to_ablate,
                    number_of_channels)
                weights.extend(new_scores)

        weights = np.float32(weights)
        weights = weights.reshape(activations.shape[:2])
        original_scores = original_scores[:, None]
        # A channel whose ablation causes a large drop in the target score
        # gets a large weight: weight = (original - ablated) / original.
        weights = (original_scores - weights) / original_scores

        # Replace the model back to its original state.
        replace_layer_recursive(self.model, ablation_layer, target_layer)
        return weights
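

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): drives
# AblationCAM through the BaseCAM __call__ interface. The torchvision
# ResNet-50, the choice of model.layer4[-1] as target layer, the random input
# tensor and the ImageNet class index 281 are assumptions made for this
# example; ClassifierOutputTarget comes from
# pytorch_grad_cam.utils.model_targets and may differ across library versions.
# Adapt the model, target layer and targets to your own setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torchvision.models import resnet50
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    model = resnet50(pretrained=True).eval()
    # A random image-sized batch just to exercise the API; in practice this
    # would be a preprocessed, normalized image tensor.
    input_tensor = torch.randn(1, 3, 224, 224)

    cam = AblationCAM(model=model,
                      target_layers=[model.layer4[-1]],
                      batch_size=32,
                      ratio_channels_to_ablate=1.0)
    # grayscale_cam has shape (batch, H, W); it can be overlaid on the input
    # image, e.g. with pytorch_grad_cam.utils.image.show_cam_on_image.
    grayscale_cam = cam(input_tensor=input_tensor,
                        targets=[ClassifierOutputTarget(281)])
    print(grayscale_cam.shape)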