import cv2 import numpy as np import torch import tqdm from pytorch_grad_cam.base_cam import BaseCAM class AblationLayer(torch.nn.Module): def __init__(self, layer, reshape_transform, indices): super(AblationLayer, self).__init__() self.layer = layer self.reshape_transform = reshape_transform # The channels to zero out: self.indices = indices def forward(self, x): self.__call__(x) def __call__(self, x): output = self.layer(x) # Hack to work with ViT, # Since the activation channels are last and not first like in CNNs # Probably should remove it? if self.reshape_transform is not None: output = output.transpose(1, 2) for i in range(output.size(0)): # Commonly the minimum activation will be 0, # And then it makes sense to zero it out. # However depending on the architecture, # If the values can be negative, we use very negative values # to perform the ablation, deviating from the paper. if torch.min(output) == 0: output[i, self.indices[i], :] = 0 else: ABLATION_VALUE = 1e5 output[i, self.indices[i], :] = torch.min( output) - ABLATION_VALUE if self.reshape_transform is not None: output = output.transpose(2, 1) return output def replace_layer_recursive(model, old_layer, new_layer): for name, layer in model._modules.items(): if layer == old_layer: model._modules[name] = new_layer return True elif replace_layer_recursive(layer, old_layer, new_layer): return True return False class AblationCAM(BaseCAM): def __init__(self, model, target_layers, use_cuda=False, reshape_transform=None): super(AblationCAM, self).__init__(model, target_layers, use_cuda, reshape_transform) if len(target_layers) > 1: print( "Warning. You are usign Ablation CAM with more than 1 layers. " "This is supported only if all layers have the same output shape") def set_ablation_layers(self): self.ablation_layers = [] for target_layer in self.target_layers: ablation_layer = AblationLayer(target_layer, self.reshape_transform, indices=[]) self.ablation_layers.append(ablation_layer) replace_layer_recursive(self.model, target_layer, ablation_layer) def unset_ablation_layers(self): # replace the model back to the original state for ablation_layer, target_layer in zip( self.ablation_layers, self.target_layers): replace_layer_recursive(self.model, ablation_layer, target_layer) def set_ablation_layer_batch_indices(self, indices): for ablation_layer in self.ablation_layers: ablation_layer.indices = indices def trim_ablation_layer_batch_indices(self, keep): for ablation_layer in self.ablation_layers: ablation_layer.indices = ablation_layer.indices[:keep] def get_cam_weights(self, input_tensor, target_category, activations, grads): with torch.no_grad(): outputs = self.model(input_tensor).cpu().numpy() original_scores = [] for i in range(input_tensor.size(0)): original_scores.append(outputs[i, target_category[i]]) original_scores = np.float32(original_scores) self.set_ablation_layers() if hasattr(self, "batch_size"): BATCH_SIZE = self.batch_size else: BATCH_SIZE = 32 number_of_channels = activations.shape[1] weights = [] with torch.no_grad(): # Iterate over the input batch for tensor, category in zip(input_tensor, target_category): batch_tensor = tensor.repeat(BATCH_SIZE, 1, 1, 1) for i in tqdm.tqdm(range(0, number_of_channels, BATCH_SIZE)): self.set_ablation_layer_batch_indices( list(range(i, i + BATCH_SIZE))) if i + BATCH_SIZE > number_of_channels: keep = number_of_channels - i batch_tensor = batch_tensor[:keep] self.trim_ablation_layer_batch_indices(self, keep) score = self.model(batch_tensor)[:, category].cpu().numpy() weights.extend(score) weights = np.float32(weights) weights = weights.reshape(activations.shape[:2]) original_scores = original_scores[:, None] weights = (original_scores - weights) / original_scores # replace the model back to the original state self.unset_ablation_layers() return weights