import torch


def fasterrcnn_reshape_transform(x):
    """Merge a mapping of feature maps into a single channels-stacked tensor.

    Every tensor in ``x`` is made non-negative with ``torch.abs`` and
    bilinearly resized to the spatial size (last two dims) of ``x['pool']``.
    The resized maps are then concatenated along dim 1, in the mapping's
    iteration order.

    Args:
        x: Mapping of name -> 4-D tensor (bilinear interpolation requires
            4-D input). All entries must share the same dim-0 size. It must
            contain a ``'pool'`` entry. NOTE(review): presumably the output
            dict of a Faster R-CNN FPN backbone — confirm against callers.

    Returns:
        Tensor of shape ``(N, sum_of_channels, H_pool, W_pool)``.

    Raises:
        KeyError: if ``x`` has no ``'pool'`` entry.
    """
    target_size = x['pool'].size()[-2:]
    # align_corners=False is already the bilinear default; passing it
    # explicitly keeps behavior identical and avoids the default-change
    # warning that some torch versions emit.
    activations = [
        torch.nn.functional.interpolate(
            torch.abs(value), target_size, mode='bilinear', align_corners=False
        )
        for value in x.values()
    ]
    return torch.cat(activations, dim=1)


def swinT_reshape_transform(tensor, height=7, width=7):
    """Turn a Swin Transformer token sequence into a channels-first map.

    Args:
        tensor: 3-D tensor of shape ``(B, height * width, C)``.
        height: Number of token rows in the grid.
        width: Number of token columns in the grid.

    Returns:
        Tensor of shape ``(B, C, height, width)``.
    """
    result = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
    # Bring the channels to the first (non-batch) dimension, like in CNNs.
    return result.permute(0, 3, 1, 2)


def vit_reshape_transform(tensor, height=14, width=14, *, num_prefix_tokens=1):
    """Turn a ViT token sequence into a channels-first map of patch tokens.

    The first ``num_prefix_tokens`` tokens are dropped. By default this is a
    single leading token, typically the CLS token. The remaining patch tokens
    are laid out on a ``height x width`` grid.

    Args:
        tensor: 3-D tensor of shape
            ``(B, num_prefix_tokens + height * width, C)``.
        height: Number of patch rows in the grid.
        width: Number of patch columns in the grid.
        num_prefix_tokens: Leading non-patch tokens to discard (keyword-only).
            Use e.g. 2 for models that carry an extra token after CLS.

    Returns:
        Tensor of shape ``(B, C, height, width)``.
    """
    patch_tokens = tensor[:, num_prefix_tokens:, :]
    result = patch_tokens.reshape(tensor.size(0), height, width, tensor.size(2))
    # Bring the channels to the first (non-batch) dimension, like in CNNs.
    return result.permute(0, 3, 1, 2)