image_interprebility/pytorch_grad_cam/utils/reshape_transforms.py

35 lines
1010 B
Python
Raw Normal View History

2023-06-05 15:11:03 +08:00
import torch
def fasterrcnn_reshape_transform(x, target_key='pool'):
    """Merge a dict of feature maps into a single tensor for CAM computation.

    Every tensor in ``x`` is replaced by its absolute value and resized with
    bilinear interpolation to the spatial size (last two dims) of
    ``x[target_key]``. The results are then concatenated along dim 1 in the
    dict's iteration order.

    Args:
        x: Mapping of name -> feature tensor. Bilinear interpolation requires
            4-D tensors, presumably (N, C, H, W) FPN outputs; confirm
            against the caller.
        target_key: Key whose tensor defines the output spatial size.
            Defaults to ``'pool'``, the previously hard-coded key.

    Returns:
        Tensor whose dim 1 is the sum of the inputs' dim-1 sizes and whose
        last two dims match ``x[target_key]``.

    Raises:
        KeyError: If ``target_key`` is not present in ``x``.
    """
    target_size = x[target_key].size()[-2:]
    # Keys are not needed; only the tensors are resized and stacked.
    # align_corners=False is the default for 'bilinear'; it is spelled out
    # so the resampling behavior is explicit.
    activations = [
        torch.nn.functional.interpolate(
            torch.abs(value),
            target_size,
            mode='bilinear',
            align_corners=False,
        )
        for value in x.values()
    ]
    return torch.cat(activations, dim=1)
def swinT_reshape_transform(tensor, height=7, width=7):
    """Turn a token sequence into a channels-first spatial map.

    The input is reshaped to ``(tensor.size(0), height, width, tensor.size(2))``
    and then permuted to ``(batch, channels, height, width)``, the layout
    CNN-style CAM code expects.

    Assumes a 3-D ``(batch, tokens, channels)`` input with
    ``tokens == height * width``; confirm against the model's layer output.
    """
    batch_size = tensor.size(0)
    embed_dim = tensor.size(2)
    grid = tensor.reshape(batch_size, height, width, embed_dim)
    # Channels-last -> channels-first (same view as two chained transposes).
    return grid.permute(0, 3, 1, 2)
def vit_reshape_transform(tensor, height=14, width=14):
    """Drop the first token and fold the rest into a channels-first map.

    The token at index 0 along dim 1 is discarded; this is presumably the
    class token of a ViT, so verify this for the model in use. The remaining
    tokens are reshaped to ``(tensor.size(0), height, width, tensor.size(2))``
    and permuted to ``(batch, channels, height, width)``.

    Assumes a 3-D ``(batch, 1 + height * width, channels)`` input; confirm
    against the model's layer output.
    """
    patch_tokens = tensor[:, 1:, :]
    batch_size = tensor.size(0)
    embed_dim = tensor.size(2)
    grid = patch_tokens.reshape(batch_size, height, width, embed_dim)
    # Channels-last -> channels-first, matching CNN feature-map layout.
    return grid.permute(0, 3, 1, 2)