184 lines
6.9 KiB
Python
184 lines
6.9 KiB
Python
|
import matplotlib
|
||
|
from matplotlib import pyplot as plt
|
||
|
from matplotlib.lines import Line2D
|
||
|
import cv2
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
from torchvision.transforms import Compose, Normalize, ToTensor
|
||
|
from typing import List, Dict
|
||
|
import math
|
||
|
|
||
|
|
||
|
def preprocess_image(
        img: np.ndarray,
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]) -> torch.Tensor:
    """ Turn an image array into a normalized tensor with a batch axis.

    A copy of the image is run through ``ToTensor`` and then
    ``Normalize(mean, std)``; the result gets a new leading dimension of
    size one.

    :param img: The image as a numpy array.
    :param mean: Per-channel means handed to ``Normalize``.
    :param std: Per-channel standard deviations handed to ``Normalize``.
    :returns: The normalized tensor with an extra leading dimension.
    """
    to_normalized_tensor = Compose([ToTensor(), Normalize(mean=mean, std=std)])
    # The transforms operate on a copy, never on the caller's own array.
    tensor = to_normalized_tensor(img.copy())
    return tensor.unsqueeze(0)
|
||
|
|
||
|
|
||
|
def deprocess_image(img):
    """ Standardize an array and convert it to a displayable 8-bit image.

    The input is shifted to zero mean, divided by its standard deviation
    (plus a small epsilon), scaled by 0.1, offset by 0.5, clipped to
    [0, 1] and finally multiplied by 255 and cast to ``np.uint8``.

    see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65

    :param img: The array to convert.
    :returns: The converted ``np.uint8`` array.
    """
    centered = img - np.mean(img)
    # The epsilon keeps a constant input from dividing by zero.
    standardized = centered / (np.std(centered) + 1e-5)
    shifted = standardized * 0.1 + 0.5
    return np.uint8(np.clip(shifted, 0, 1) * 255)
|
||
|
|
||
|
|
||
|
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET,
                      image_weight: float = 0.5) -> np.ndarray:
    """ This function overlays the cam mask on the image as a heatmap.
    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format, np.float32 in the range [0, 1].
    :param mask: The cam mask. It is multiplied by 255 and cast to uint8
                 before the colormap is applied (presumably values in [0, 1]).
    :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :param image_weight: The final result is image_weight * img + (1-image_weight) * mask.
    :returns: The default image with the cam overlay, as np.uint8.
    :raises ValueError: If 'img' has values above 1 or if 'image_weight'
                        is outside the range [0, 1].
    """
    # Validate the cheap preconditions first so bad input fails before any
    # colormap work is done.
    if np.max(img) > 1:
        raise ValueError(
            "The input image should be np.float32 in the range [0, 1]")

    if image_weight < 0 or image_weight > 1:
        raise ValueError(
            f"image_weight should be in the range [0, 1]. Got: {image_weight}")

    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    cam = (1 - image_weight) * heatmap + image_weight * img
    # Rescale so the brightest value becomes 1. An all-zero blend (e.g. a
    # black image with image_weight == 1) is left untouched instead of being
    # divided by zero, which would fill the result with NaN.
    peak = np.max(cam)
    if peak != 0:
        cam = cam / peak
    return np.uint8(255 * cam)
|
||
|
|
||
|
|
||
|
def create_labels_legend(concept_scores: np.ndarray,
                         labels: Dict[int, str],
                         top_k=2):
    """ Build one multi-line legend string per concept.

    For every row of 'concept_scores' the 'top_k' highest scoring categories
    are listed, best first, one per line, formatted as "<label>:<score>".
    Only the first three comma separated parts of a label are kept and the
    score is printed with two decimals.

    :param concept_scores: A 2D array with a row per concept and a column per category.
    :param labels: Maps a category index to its label text.
    :param top_k: How many categories to list for every concept.
    :returns: A list with one string per concept.
    """
    def _describe(row, category):
        short_name = ','.join(labels[category].split(',')[:3])
        return f"{short_name}:{concept_scores[row, category]:.2f}"

    # Sort ascending, flip to descending, then keep the leading top_k columns.
    ranked = np.argsort(concept_scores, axis=1)[:, ::-1][:, :top_k]
    return [
        "\n".join([_describe(row, category) for category in best_categories])
        for row, best_categories in enumerate(ranked)
    ]
|
||
|
|
||
|
|
||
|
def show_factorization_on_image(img: np.ndarray,
                                explanations: np.ndarray,
                                colors: List[np.ndarray] = None,
                                image_weight: float = 0.5,
                                concept_labels: List = None) -> np.ndarray:
    """ Color code the different component heatmaps on top of the image.
        Every component color code will be magnified according to the heatmap intensity
        (by modifying the V channel in the HSV color space),
        and optionally create a legend that shows the labels.

        Since different factorization component heatmaps can overlap in principle,
        we need a strategy to decide how to deal with the overlaps.
        This keeps the component that has a higher value in its heatmap.

        The 'explanations' array passed in is not modified.

    :param img: The base image RGB format (presumably float values in [0, 1],
                since the blended result is multiplied by 255).
    :param explanations: A tensor of shape num_components x height x width, with the component visualizations.
    :param colors: List of R, G, B colors to be used for the components.
                   If None, will use the gist_rainbow cmap as a default.
    :param image_weight: The final result is image_weight * img + (1-image_weight) * visualization.
    :param concept_labels: A list of strings for every component. If this is passed, a legend that shows
                           the labels and their colors will be added to the image.
    :returns: The visualized image as np.uint8. When 'concept_labels' is given,
              the rendered legend is stacked to the right of the image.
    """
    n_components = explanations.shape[0]
    if colors is None:
        # taken from https://github.com/edocollins/DFF/blob/master/utils.py
        # pyplot.get_cmap is used because matplotlib.cm.get_cmap no longer
        # exists in recent matplotlib releases.
        _cmap = plt.get_cmap('gist_rainbow')
        colors = [
            np.array(_cmap(i))
            for i in np.arange(0, 1, 1.0 / n_components)]

    # Index of the strongest component at every pixel; it decides overlaps.
    concept_per_pixel = explanations.argmax(axis=0)
    masks = []
    for i in range(n_components):
        mask = np.zeros(shape=(img.shape[0], img.shape[1], 3))
        mask[:, :, :] = colors[i][:3]
        # Work on a copy: explanations[i] is a view, and zeroing it in place
        # would silently change the caller's array.
        explanation = explanations[i].copy()
        explanation[concept_per_pixel != i] = 0
        mask = np.uint8(mask * 255)
        mask = cv2.cvtColor(mask, cv2.COLOR_RGB2HSV)
        # The V channel carries the heatmap intensity of this component.
        mask[:, :, 2] = np.uint8(255 * explanation)
        mask = cv2.cvtColor(mask, cv2.COLOR_HSV2RGB)
        mask = np.float32(mask) / 255
        masks.append(mask)

    mask = np.sum(np.float32(masks), axis=0)
    result = img * image_weight + mask * (1 - image_weight)
    result = np.uint8(result * 255)

    if concept_labels is not None:
        px = 1 / plt.rcParams['figure.dpi']  # pixel in inches
        fig = plt.figure(figsize=(result.shape[1] * px, result.shape[0] * px))
        try:
            # The font size is passed to this legend only, rather than being
            # written into the global rcParams where it would leak into every
            # later plot.
            legend_fontsize = int(
                14 * result.shape[0] / 256 / max(1, n_components / 6))
            lw = 5 * result.shape[0] / 256
            lines = [Line2D([0], [0], color=colors[i], lw=lw)
                     for i in range(n_components)]
            plt.legend(lines,
                       concept_labels,
                       mode="expand",
                       fancybox=True,
                       shadow=True,
                       fontsize=legend_fontsize)

            plt.tight_layout(pad=0, w_pad=0, h_pad=0)
            plt.axis('off')
            canvas = fig.canvas
            canvas.draw()
            if hasattr(canvas, "buffer_rgba"):
                # tostring_rgb is gone from recent matplotlib releases; read
                # the RGBA buffer and drop the alpha channel instead.
                data = np.asarray(canvas.buffer_rgba())[:, :, :3]
            else:
                width, height = canvas.get_width_height()
                data = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
                data = data.reshape((height, width, 3))
        finally:
            # Always release the figure, even if rendering failed.
            plt.close(fig=fig)
        data = cv2.resize(np.ascontiguousarray(data),
                          (result.shape[1], result.shape[0]))
        result = np.hstack((result, data))
    return result
|
||
|
|
||
|
|
||
|
def scale_cam_image(cam, target_size=None):
    """ Min-max normalize every image in 'cam' and optionally resize it.

    Each image has its minimum subtracted and is then divided by its new
    maximum (plus a small epsilon), independently of the other images.

    :param cam: An iterable of images.
    :param target_size: If not None, every normalized image is resized with
                        cv2.resize to this size.
    :returns: The processed images stacked into a single np.float32 array.
    """
    def _normalize(single):
        shifted = single - np.min(single)
        # The epsilon avoids a division by zero for constant images.
        scaled = shifted / (1e-7 + np.max(shifted))
        if target_size is None:
            return scaled
        return cv2.resize(scaled, target_size)

    return np.float32([_normalize(single) for single in cam])
|
||
|
|
||
|
|
||
|
def scale_accross_batch_and_channels(tensor, target_size):
    """ Apply scale_cam_image to every (batch, channel) map of a tensor.

    The first two axes are merged so that each map is scaled on its own,
    then the result is reshaped back to one entry per batch and channel.

    :param tensor: An array whose first two axes are batch and channel.
    :param target_size: The size handed to scale_cam_image. Its second
                        entry becomes the third output axis and its first
                        entry the last output axis.
    :returns: An array of shape
              batch x channel x target_size[1] x target_size[0].
    """
    n_batch, n_channels = tensor.shape[:2]
    spatial_shape = tensor.shape[2:]
    # Merge batch and channel into one leading axis of independent maps.
    flat_maps = tensor.reshape(n_batch * n_channels, *spatial_shape)
    scaled_maps = scale_cam_image(flat_maps, target_size)
    return scaled_maps.reshape(
        n_batch, n_channels, target_size[1], target_size[0])
|