35 lines
1010 B
Python
35 lines
1010 B
Python
|
import torch
|
||
|
|
||
|
|
||
|
def fasterrcnn_reshape_transform(x):
    """Merge a dict of feature maps into one channel-stacked activation map.

    Every tensor in ``x`` is made non-negative with ``torch.abs``, bilinearly
    resized to the spatial size (last two dimensions) of ``x['pool']``, and
    the results are concatenated along dimension 1 in the dict's iteration
    order.

    Args:
        x: mapping of name -> tensor. Must contain the key ``'pool'``. Each
            value is presumably a 4-D (batch, channels, height, width)
            feature map, since bilinear ``interpolate`` requires 4-D input —
            confirm with the caller (e.g. an FPN backbone's output dict).

    Returns:
        A tensor whose dimension 1 is the sum of all input channel counts and
        whose last two dimensions match those of ``x['pool']``.

    Raises:
        KeyError: if ``x`` has no ``'pool'`` entry.
    """
    target_size = x['pool'].size()[-2:]
    # Only the tensors matter here; the dict keys are not used.
    # align_corners=False is the bilinear default; stating it explicitly
    # keeps results identical while silencing the warning some PyTorch
    # versions emit when it is left unspecified.
    activations = [
        torch.nn.functional.interpolate(
            torch.abs(value),
            target_size,
            mode='bilinear',
            align_corners=False,
        )
        for value in x.values()
    ]
    # `dim` is torch.cat's documented keyword (the original used `axis`).
    return torch.cat(activations, dim=1)
|
||
|
|
||
|
|
||
|
def swinT_reshape_transform(tensor, height=7, width=7):
    """Lay out Swin Transformer token activations as a CNN-style feature map.

    Args:
        tensor: token activations. The code takes ``tensor.size(0)`` as the
            batch size and ``tensor.size(2)`` as the channel count, so it is
            presumably shaped (batch, height * width, channels) — confirm
            with the caller.
        height: number of token rows in the spatial grid.
        width: number of token columns in the spatial grid.

    Returns:
        A view of shape (batch, channels, height, width).
    """
    batch_size = tensor.size(0)
    channels = tensor.size(2)
    grid = tensor.reshape(batch_size, height, width, channels)
    # (batch, height, width, channels) -> (batch, channels, height, width):
    # channels-first, matching how convolutional activations are laid out.
    return grid.permute(0, 3, 1, 2)
|
||
|
|
||
|
|
||
|
def vit_reshape_transform(tensor, height=14, width=14):
    """Lay out ViT token activations as a CNN-style feature map.

    The first token along dimension 1 is discarded (presumably the class
    token — confirm for the model in use); the remaining tokens are arranged
    into a ``height`` x ``width`` grid.

    Args:
        tensor: token activations, presumably shaped
            (batch, 1 + height * width, channels) — confirm with the caller.
            ``tensor.size(0)`` is used as the batch size and
            ``tensor.size(2)`` as the channel count.
        height: number of patch rows in the spatial grid.
        width: number of patch columns in the spatial grid.

    Returns:
        A tensor of shape (batch, channels, height, width).
    """
    patch_tokens = tensor[:, 1:, :]
    spatial = patch_tokens.reshape(
        tensor.size(0), height, width, tensor.size(2))
    # (batch, height, width, channels) -> (batch, channels, height, width):
    # channels-first, matching how convolutional activations are laid out.
    return spatial.permute(0, 3, 1, 2)
|