class ActivationsAndGradients:
    """ Class for extracting activations and
    registering gradients from targeted intermediate layers """

    def __init__(self, model, target_layers, reshape_transform):
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        self.handles = []
        for target_layer in target_layers:
            self.handles.append(
                target_layer.register_forward_hook(self.save_activation))
            # Because of https://github.com/pytorch/pytorch/issues/61519,
            # we don't use a backward hook to record gradients.
            self.handles.append(
                target_layer.register_forward_hook(self.save_gradient))

    def save_activation(self, module, input, output):
        activation = output

        if self.reshape_transform is not None:
            activation = self.reshape_transform(activation)
        self.activations.append(activation.cpu().detach())

    def save_gradient(self, module, input, output):
        if not hasattr(output, "requires_grad") or not output.requires_grad:
            # Hooks can only be registered on tensors that require grad.
            return

        # Gradients arrive in reverse order of the forward pass, so prepend
        # to keep self.gradients aligned with self.activations.
        def _store_grad(grad):
            if self.reshape_transform is not None:
                grad = self.reshape_transform(grad)
            self.gradients = [grad.cpu().detach()] + self.gradients

        output.register_hook(_store_grad)

    def __call__(self, x):
        # Reset the stored tensors and run a fresh forward pass.
        self.gradients = []
        self.activations = []
        return self.model(x)

    def release(self):
        # Remove all registered hooks from the model.
        for handle in self.handles:
            handle.remove()

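# --- Usage sketch (not part of the original file) ---
# A minimal illustration of the intended flow: hook a target layer, run a
# forward pass, backpropagate a scalar score, then read the captured
# activations and gradients. The torchvision ResNet-18 and the choice of
# model.layer4 below are assumptions made for this example only.
if __name__ == "__main__":
    import torch
    from torchvision.models import resnet18

    model = resnet18(weights=None).eval()
    acts_and_grads = ActivationsAndGradients(
        model, target_layers=[model.layer4], reshape_transform=None)

    x = torch.randn(1, 3, 224, 224)        # dummy input batch
    output = acts_and_grads(x)             # forward pass; save_activation fires here
    output[0, output.argmax()].backward()  # _store_grad fires during backward

    activations = acts_and_grads.activations[0]  # saved feature maps (CPU tensors)
    gradients = acts_and_grads.gradients[0]      # gradients w.r.t. those feature maps
    acts_and_grads.release()                     # remove the hooks when done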