commit 478335a88525c4f6ed70d3fe070a6db4d4febcdf
Author: zhouminyang <784185055@qq.com>
Date: Mon Jun 5 15:11:03 2023 +0800
1.0
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/image_interprebility.iml b/.idea/image_interprebility.iml
new file mode 100644
index 0000000..2b10bc0
--- /dev/null
+++ b/.idea/image_interprebility.iml
@@ -0,0 +1,12 @@
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..1d8b614
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,18 @@
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..cbc002c
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..13620b5
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/__pycache__/utils.cpython-38.pyc b/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000..361d507
Binary files /dev/null and b/__pycache__/utils.cpython-38.pyc differ
diff --git a/api.py b/api.py
new file mode 100644
index 0000000..5de005e
--- /dev/null
+++ b/api.py
@@ -0,0 +1,149 @@
+import argparse
+import cv2
+import numpy as np
+import torch
+from torchvision import models
+from pytorch_grad_cam import GradCAM, \
+ HiResCAM, \
+ ScoreCAM, \
+ GradCAMPlusPlus, \
+ AblationCAM, \
+ XGradCAM, \
+ EigenCAM, \
+ EigenGradCAM, \
+ LayerCAM, \
+ FullGrad, \
+ GradCAMElementWise
+
+
+from pytorch_grad_cam import GuidedBackpropReLUModel
+from pytorch_grad_cam.utils.image import show_cam_on_image, \
+ deprocess_image, \
+ preprocess_image
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--use-cuda', action='store_true', default=False,
+ help='Use NVIDIA GPU acceleration')
+ parser.add_argument(
+ '--image-path',
+ type=str,
+ default='./examples/both.png',
+ help='Input image path')
+ parser.add_argument('--aug_smooth', action='store_true',
+ help='Apply test time augmentation to smooth the CAM')
+ parser.add_argument(
+ '--eigen_smooth',
+ action='store_true',
+        help='Reduce noise by taking the first principal component'
+             ' of cam_weights*activations')
+    parser.add_argument('--method', type=str, default='gradcam',
+                        choices=['gradcam', 'hirescam', 'gradcam++',
+                                 'scorecam', 'xgradcam', 'ablationcam',
+                                 'eigencam', 'eigengradcam', 'layercam',
+                                 'fullgrad', 'gradcamelementwise'],
+                        help='CAM method to use; see choices for the '
+                             'full list of supported values')
+
+ args = parser.parse_args()
+ args.use_cuda = args.use_cuda and torch.cuda.is_available()
+ if args.use_cuda:
+ print('Using GPU for acceleration')
+ else:
+ print('Using CPU for computation')
+
+ return args
+
+
+def api(image_path, method, model_name, **kwargs):
+ args = get_args()
+ methods = \
+ {"gradcam": GradCAM,
+ "hirescam": HiResCAM,
+ "scorecam": ScoreCAM,
+ "gradcam++": GradCAMPlusPlus,
+ "ablationcam": AblationCAM,
+ "xgradcam": XGradCAM,
+ "eigencam": EigenCAM,
+ "eigengradcam": EigenGradCAM,
+ "layercam": LayerCAM,
+ "fullgrad": FullGrad,
+ "gradcamelementwise": GradCAMElementWise}
+
+    # Look the model up by name; getattr avoids an unsafe eval().
+    model = getattr(models, model_name)(pretrained=True)
+    print(model)
+ # Choose the target layer you want to compute the visualization for.
+ # Usually this will be the last convolutional layer in the model.
+ # Some common choices can be:
+ # Resnet18 and 50: model.layer4
+ # VGG, densenet161: model.features[-1]
+ # mnasnet1_0: model.layers[-1]
+    # You can print the model to help choose the layer
+ # You can pass a list with several target layers,
+ # in that case the CAMs will be computed per layer and then aggregated.
+ # You can also try selecting all layers of a certain type, with e.g:
+ # from pytorch_grad_cam.utils.find_layers import find_layer_types_recursive
+ # find_layer_types_recursive(model, [torch.nn.ReLU])
+    # target_layers = [model.layer4]
+    target_layer = kwargs['target_layer']
+    target_layers = [getattr(model, target_layer)]
+
+ rgb_img = cv2.imread(image_path, 1)[:, :, ::-1]
+ rgb_img = np.float32(rgb_img) / 255
+ input_tensor = preprocess_image(rgb_img,
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+
+ # We have to specify the target we want to generate
+ # the Class Activation Maps for.
+ # If targets is None, the highest scoring category (for every member in the batch) will be used.
+    # You can target specific categories, e.g.:
+    # targets = [ClassifierOutputTarget(281)]  # 281 = "tabby, tabby cat"
+ targets = None
+
+ # Using the with statement ensures the context is freed, and you can
+ # recreate different CAM objects in a loop.
+ cam_algorithm = methods[method]
+ with cam_algorithm(model=model,
+ target_layers=target_layers,
+ use_cuda=args.use_cuda) as cam:
+
+ # AblationCAM and ScoreCAM have batched implementations.
+ # You can override the internal batch size for faster computation.
+ cam.batch_size = 32
+        aug_smooth = kwargs['aug_smooth']
+        # Prefer the caller's eigen_smooth flag over the CLI default.
+        eigen_smooth = kwargs.get('eigen_smooth', args.eigen_smooth)
+        grayscale_cam = cam(input_tensor=input_tensor,
+                            targets=targets,
+                            aug_smooth=aug_smooth, eigen_smooth=eigen_smooth)
+
+ # Here grayscale_cam has only one image in the batch
+ grayscale_cam = grayscale_cam[0, :]
+
+ cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
+
+ # cam_image is RGB encoded whereas "cv2.imwrite" requires BGR encoding.
+ cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)
+
+ gb_model = GuidedBackpropReLUModel(model=model, use_cuda=args.use_cuda)
+ gb = gb_model(input_tensor, target_category=None)
+
+ cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam])
+ cam_gb = deprocess_image(cam_mask * gb)
+ gb = deprocess_image(gb)
+
+ cv2.imwrite(f'{method}_cam.jpg', cam_image)
+ cv2.imwrite(f'{method}_gb.jpg', gb)
+ cv2.imwrite(f'{method}_cam_gb.jpg', cam_gb)
+    return f'{method}_gb.jpg'
+
+kwargs={"target_layer":'layer1',
+ "aug_smooth":True,
+ "eigen_smooth":True}
+path=api('sample/both.png','fullgrad','resnet',**kwargs)
+print(path)
\ No newline at end of file
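
The `targets` comment in api.py hints at steering the CAM toward a chosen class instead of the top-scoring one. Below is a minimal standalone sketch of that (not part of the commit; the random tensor stands in for a real preprocessed image batch):

    import torch
    from torchvision import models
    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    model = models.resnet50(pretrained=True).eval()
    # A random tensor stands in for a preprocessed image batch.
    input_tensor = torch.randn(1, 3, 224, 224)
    # Index 281 is "tabby, tabby cat" in imagenet_1000.json below.
    targets = [ClassifierOutputTarget(281)]
    with GradCAM(model=model, target_layers=[model.layer4],
                 use_cuda=False) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
    print(grayscale_cam.shape)  # one (224, 224) map per batch item
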
diff --git a/imagenet_1000.json b/imagenet_1000.json
new file mode 100644
index 0000000..0c06838
--- /dev/null
+++ b/imagenet_1000.json
@@ -0,0 +1 @@
+{"0": "tench, Tinca tinca", "1": "goldfish, Carassius auratus", "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", "3": "tiger shark, Galeocerdo cuvieri", "4": "hammerhead, hammerhead shark", "5": "electric ray, crampfish, numbfish, torpedo", "6": "stingray", "7": "cock", "8": "hen", "9": "ostrich, Struthio camelus", "10": "brambling, Fringilla montifringilla", "11": "goldfinch, Carduelis carduelis", "12": "house finch, linnet, Carpodacus mexicanus", "13": "junco, snowbird", "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea", "15": "robin, American robin, Turdus migratorius", "16": "bulbul", "17": "jay", "18": "magpie", "19": "chickadee", "20": "water ouzel, dipper", "21": "kite", "22": "bald eagle, American eagle, Haliaeetus leucocephalus", "23": "vulture", "24": "great grey owl, great gray owl, Strix nebulosa", "25": "European fire salamander, Salamandra salamandra", "26": "common newt, Triturus vulgaris", "27": "eft", "28": "spotted salamander, Ambystoma maculatum", "29": "axolotl, mud puppy, Ambystoma mexicanum", "30": "bullfrog, Rana catesbeiana", "31": "tree frog, tree-frog", "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", "33": "loggerhead, loggerhead turtle, Caretta caretta", "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", "35": "mud turtle", "36": "terrapin", "37": "box turtle, box tortoise", "38": "banded gecko", "39": "common iguana, iguana, Iguana iguana", "40": "American chameleon, anole, Anolis carolinensis", "41": "whiptail, whiptail lizard", "42": "agama", "43": "frilled lizard, Chlamydosaurus kingi", "44": "alligator lizard", "45": "Gila monster, Heloderma suspectum", "46": "green lizard, Lacerta viridis", "47": "African chameleon, Chamaeleo chamaeleon", "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", "49": "African crocodile, Nile crocodile, Crocodylus niloticus", "50": "American alligator, Alligator mississipiensis", "51": "triceratops", "52": "thunder snake, worm snake, Carphophis amoenus", "53": "ringneck snake, ring-necked snake, ring snake", "54": "hognose snake, puff adder, sand viper", "55": "green snake, grass snake", "56": "king snake, kingsnake", "57": "garter snake, grass snake", "58": "water snake", "59": "vine snake", "60": "night snake, Hypsiglena torquata", "61": "boa constrictor, Constrictor constrictor", "62": "rock python, rock snake, Python sebae", "63": "Indian cobra, Naja naja", "64": "green mamba", "65": "sea snake", "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus", "68": "sidewinder, horned rattlesnake, Crotalus cerastes", "69": "trilobite", "70": "harvestman, daddy longlegs, Phalangium opilio", "71": "scorpion", "72": "black and gold garden spider, Argiope aurantia", "73": "barn spider, Araneus cavaticus", "74": "garden spider, Aranea diademata", "75": "black widow, Latrodectus mactans", "76": "tarantula", "77": "wolf spider, hunting spider", "78": "tick", "79": "centipede", "80": "black grouse", "81": "ptarmigan", "82": "ruffed grouse, partridge, Bonasa umbellus", "83": "prairie chicken, prairie grouse, prairie fowl", "84": "peacock", "85": "quail", "86": "partridge", "87": "African grey, African gray, Psittacus erithacus", "88": "macaw", "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", "90": "lorikeet", "91": "coucal", "92": "bee eater", "93": "hornbill", "94": "hummingbird", "95": 
"jacamar", "96": "toucan", "97": "drake", "98": "red-breasted merganser, Mergus serrator", "99": "goose", "100": "black swan, Cygnus atratus", "101": "tusker", "102": "echidna, spiny anteater, anteater", "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", "104": "wallaby, brush kangaroo", "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", "106": "wombat", "107": "jellyfish", "108": "sea anemone, anemone", "109": "brain coral", "110": "flatworm, platyhelminth", "111": "nematode, nematode worm, roundworm", "112": "conch", "113": "snail", "114": "slug", "115": "sea slug, nudibranch", "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore", "117": "chambered nautilus, pearly nautilus, nautilus", "118": "Dungeness crab, Cancer magister", "119": "rock crab, Cancer irroratus", "120": "fiddler crab", "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus", "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", "124": "crayfish, crawfish, crawdad, crawdaddy", "125": "hermit crab", "126": "isopod", "127": "white stork, Ciconia ciconia", "128": "black stork, Ciconia nigra", "129": "spoonbill", "130": "flamingo", "131": "little blue heron, Egretta caerulea", "132": "American egret, great white heron, Egretta albus", "133": "bittern", "134": "crane", "135": "limpkin, Aramus pictus", "136": "European gallinule, Porphyrio porphyrio", "137": "American coot, marsh hen, mud hen, water hen, Fulica americana", "138": "bustard", "139": "ruddy turnstone, Arenaria interpres", "140": "red-backed sandpiper, dunlin, Erolia alpina", "141": "redshank, Tringa totanus", "142": "dowitcher", "143": "oystercatcher, oyster catcher", "144": "pelican", "145": "king penguin, Aptenodytes patagonica", "146": "albatross, mollymawk", "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", "149": "dugong, Dugong dugon", "150": "sea lion", "151": "Chihuahua", "152": "Japanese spaniel", "153": "Maltese dog, Maltese terrier, Maltese", "154": "Pekinese, Pekingese, Peke", "155": "Shih-Tzu", "156": "Blenheim spaniel", "157": "papillon", "158": "toy terrier", "159": "Rhodesian ridgeback", "160": "Afghan hound, Afghan", "161": "basset, basset hound", "162": "beagle", "163": "bloodhound, sleuthhound", "164": "bluetick", "165": "black-and-tan coonhound", "166": "Walker hound, Walker foxhound", "167": "English foxhound", "168": "redbone", "169": "borzoi, Russian wolfhound", "170": "Irish wolfhound", "171": "Italian greyhound", "172": "whippet", "173": "Ibizan hound, Ibizan Podenco", "174": "Norwegian elkhound, elkhound", "175": "otterhound, otter hound", "176": "Saluki, gazelle hound", "177": "Scottish deerhound, deerhound", "178": "Weimaraner", "179": "Staffordshire bullterrier, Staffordshire bull terrier", "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", "181": "Bedlington terrier", "182": "Border terrier", "183": "Kerry blue terrier", "184": "Irish terrier", "185": "Norfolk terrier", "186": "Norwich terrier", "187": "Yorkshire terrier", "188": "wire-haired fox terrier", "189": "Lakeland terrier", "190": "Sealyham terrier, Sealyham", "191": "Airedale, Airedale terrier", "192": "cairn, cairn terrier", "193": "Australian terrier", "194": "Dandie Dinmont, 
Dandie Dinmont terrier", "195": "Boston bull, Boston terrier", "196": "miniature schnauzer", "197": "giant schnauzer", "198": "standard schnauzer", "199": "Scotch terrier, Scottish terrier, Scottie", "200": "Tibetan terrier, chrysanthemum dog", "201": "silky terrier, Sydney silky", "202": "soft-coated wheaten terrier", "203": "West Highland white terrier", "204": "Lhasa, Lhasa apso", "205": "flat-coated retriever", "206": "curly-coated retriever", "207": "golden retriever", "208": "Labrador retriever", "209": "Chesapeake Bay retriever", "210": "German short-haired pointer", "211": "vizsla, Hungarian pointer", "212": "English setter", "213": "Irish setter, red setter", "214": "Gordon setter", "215": "Brittany spaniel", "216": "clumber, clumber spaniel", "217": "English springer, English springer spaniel", "218": "Welsh springer spaniel", "219": "cocker spaniel, English cocker spaniel, cocker", "220": "Sussex spaniel", "221": "Irish water spaniel", "222": "kuvasz", "223": "schipperke", "224": "groenendael", "225": "malinois", "226": "briard", "227": "kelpie", "228": "komondor", "229": "Old English sheepdog, bobtail", "230": "Shetland sheepdog, Shetland sheep dog, Shetland", "231": "collie", "232": "Border collie", "233": "Bouvier des Flandres, Bouviers des Flandres", "234": "Rottweiler", "235": "German shepherd, German shepherd dog, German police dog, alsatian", "236": "Doberman, Doberman pinscher", "237": "miniature pinscher", "238": "Greater Swiss Mountain dog", "239": "Bernese mountain dog", "240": "Appenzeller", "241": "EntleBucher", "242": "boxer", "243": "bull mastiff", "244": "Tibetan mastiff", "245": "French bulldog", "246": "Great Dane", "247": "Saint Bernard, St Bernard", "248": "Eskimo dog, husky", "249": "malamute, malemute, Alaskan malamute", "250": "Siberian husky", "251": "dalmatian, coach dog, carriage dog", "252": "affenpinscher, monkey pinscher, monkey dog", "253": "basenji", "254": "pug, pug-dog", "255": "Leonberg", "256": "Newfoundland, Newfoundland dog", "257": "Great Pyrenees", "258": "Samoyed, Samoyede", "259": "Pomeranian", "260": "chow, chow chow", "261": "keeshond", "262": "Brabancon griffon", "263": "Pembroke, Pembroke Welsh corgi", "264": "Cardigan, Cardigan Welsh corgi", "265": "toy poodle", "266": "miniature poodle", "267": "standard poodle", "268": "Mexican hairless", "269": "timber wolf, grey wolf, gray wolf, Canis lupus", "270": "white wolf, Arctic wolf, Canis lupus tundrarum", "271": "red wolf, maned wolf, Canis rufus, Canis niger", "272": "coyote, prairie wolf, brush wolf, Canis latrans", "273": "dingo, warrigal, warragal, Canis dingo", "274": "dhole, Cuon alpinus", "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", "276": "hyena, hyaena", "277": "red fox, Vulpes vulpes", "278": "kit fox, Vulpes macrotis", "279": "Arctic fox, white fox, Alopex lagopus", "280": "grey fox, gray fox, Urocyon cinereoargenteus", "281": "tabby, tabby cat", "282": "tiger cat", "283": "Persian cat", "284": "Siamese cat, Siamese", "285": "Egyptian cat", "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", "287": "lynx, catamount", "288": "leopard, Panthera pardus", "289": "snow leopard, ounce, Panthera uncia", "290": "jaguar, panther, Panthera onca, Felis onca", "291": "lion, king of beasts, Panthera leo", "292": "tiger, Panthera tigris", "293": "cheetah, chetah, Acinonyx jubatus", "294": "brown bear, bruin, Ursus arctos", "295": "American black bear, black bear, Ursus americanus, Euarctos americanus", "296": "ice bear, polar 
bear, Ursus Maritimus, Thalarctos maritimus", "297": "sloth bear, Melursus ursinus, Ursus ursinus", "298": "mongoose", "299": "meerkat, mierkat", "300": "tiger beetle", "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", "302": "ground beetle, carabid beetle", "303": "long-horned beetle, longicorn, longicorn beetle", "304": "leaf beetle, chrysomelid", "305": "dung beetle", "306": "rhinoceros beetle", "307": "weevil", "308": "fly", "309": "bee", "310": "ant, emmet, pismire", "311": "grasshopper, hopper", "312": "cricket", "313": "walking stick, walkingstick, stick insect", "314": "cockroach, roach", "315": "mantis, mantid", "316": "cicada, cicala", "317": "leafhopper", "318": "lacewing, lacewing fly", "319": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", "320": "damselfly", "321": "admiral", "322": "ringlet, ringlet butterfly", "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", "324": "cabbage butterfly", "325": "sulphur butterfly, sulfur butterfly", "326": "lycaenid, lycaenid butterfly", "327": "starfish, sea star", "328": "sea urchin", "329": "sea cucumber, holothurian", "330": "wood rabbit, cottontail, cottontail rabbit", "331": "hare", "332": "Angora, Angora rabbit", "333": "hamster", "334": "porcupine, hedgehog", "335": "fox squirrel, eastern fox squirrel, Sciurus niger", "336": "marmot", "337": "beaver", "338": "guinea pig, Cavia cobaya", "339": "sorrel", "340": "zebra", "341": "hog, pig, grunter, squealer, Sus scrofa", "342": "wild boar, boar, Sus scrofa", "343": "warthog", "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius", "345": "ox", "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", "347": "bison", "348": "ram, tup", "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", "350": "ibex, Capra ibex", "351": "hartebeest", "352": "impala, Aepyceros melampus", "353": "gazelle", "354": "Arabian camel, dromedary, Camelus dromedarius", "355": "llama", "356": "weasel", "357": "mink", "358": "polecat, fitch, foulmart, foumart, Mustela putorius", "359": "black-footed ferret, ferret, Mustela nigripes", "360": "otter", "361": "skunk, polecat, wood pussy", "362": "badger", "363": "armadillo", "364": "three-toed sloth, ai, Bradypus tridactylus", "365": "orangutan, orang, orangutang, Pongo pygmaeus", "366": "gorilla, Gorilla gorilla", "367": "chimpanzee, chimp, Pan troglodytes", "368": "gibbon, Hylobates lar", "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus", "370": "guenon, guenon monkey", "371": "patas, hussar monkey, Erythrocebus patas", "372": "baboon", "373": "macaque", "374": "langur", "375": "colobus, colobus monkey", "376": "proboscis monkey, Nasalis larvatus", "377": "marmoset", "378": "capuchin, ringtail, Cebus capucinus", "379": "howler monkey, howler", "380": "titi, titi monkey", "381": "spider monkey, Ateles geoffroyi", "382": "squirrel monkey, Saimiri sciureus", "383": "Madagascar cat, ring-tailed lemur, Lemur catta", "384": "indri, indris, Indri indri, Indri brevicaudatus", "385": "Indian elephant, Elephas maximus", "386": "African elephant, Loxodonta africana", "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens", "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", "389": "barracouta, snoek", "390": "eel", "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", "392": "rock beauty, 
Holocanthus tricolor", "393": "anemone fish", "394": "sturgeon", "395": "gar, garfish, garpike, billfish, Lepisosteus osseus", "396": "lionfish", "397": "puffer, pufferfish, blowfish, globefish", "398": "abacus", "399": "abaya", "400": "academic gown, academic robe, judge's robe", "401": "accordion, piano accordion, squeeze box", "402": "acoustic guitar", "403": "aircraft carrier, carrier, flattop, attack aircraft carrier", "404": "airliner", "405": "airship, dirigible", "406": "altar", "407": "ambulance", "408": "amphibian, amphibious vehicle", "409": "analog clock", "410": "apiary, bee house", "411": "apron", "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", "413": "assault rifle, assault gun", "414": "backpack, back pack, knapsack, packsack, rucksack, haversack", "415": "bakery, bakeshop, bakehouse", "416": "balance beam, beam", "417": "balloon", "418": "ballpoint, ballpoint pen, ballpen, Biro", "419": "Band Aid", "420": "banjo", "421": "bannister, banister, balustrade, balusters, handrail", "422": "barbell", "423": "barber chair", "424": "barbershop", "425": "barn", "426": "barometer", "427": "barrel, cask", "428": "barrow, garden cart, lawn cart, wheelbarrow", "429": "baseball", "430": "basketball", "431": "bassinet", "432": "bassoon", "433": "bathing cap, swimming cap", "434": "bath towel", "435": "bathtub, bathing tub, bath, tub", "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", "437": "beacon, lighthouse, beacon light, pharos", "438": "beaker", "439": "bearskin, busby, shako", "440": "beer bottle", "441": "beer glass", "442": "bell cote, bell cot", "443": "bib", "444": "bicycle-built-for-two, tandem bicycle, tandem", "445": "bikini, two-piece", "446": "binder, ring-binder", "447": "binoculars, field glasses, opera glasses", "448": "birdhouse", "449": "boathouse", "450": "bobsled, bobsleigh, bob", "451": "bolo tie, bolo, bola tie, bola", "452": "bonnet, poke bonnet", "453": "bookcase", "454": "bookshop, bookstore, bookstall", "455": "bottlecap", "456": "bow", "457": "bow tie, bow-tie, bowtie", "458": "brass, memorial tablet, plaque", "459": "brassiere, bra, bandeau", "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty", "461": "breastplate, aegis, egis", "462": "broom", "463": "bucket, pail", "464": "buckle", "465": "bulletproof vest", "466": "bullet train, bullet", "467": "butcher shop, meat market", "468": "cab, hack, taxi, taxicab", "469": "caldron, cauldron", "470": "candle, taper, wax light", "471": "cannon", "472": "canoe", "473": "can opener, tin opener", "474": "cardigan", "475": "car mirror", "476": "carousel, carrousel, merry-go-round, roundabout, whirligig", "477": "carpenter's kit, tool kit", "478": "carton", "479": "car wheel", "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", "481": "cassette", "482": "cassette player", "483": "castle", "484": "catamaran", "485": "CD player", "486": "cello, violoncello", "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone", "488": "chain", "489": "chainlink fence", "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", "491": "chain saw, chainsaw", "492": "chest", "493": "chiffonier, commode", "494": "chime, bell, gong", "495": "china cabinet, china closet", "496": "Christmas stocking", "497": "church, church building", "498": "cinema, movie theater, movie theatre, movie house, 
picture palace", "499": "cleaver, meat cleaver, chopper", "500": "cliff dwelling", "501": "cloak", "502": "clog, geta, patten, sabot", "503": "cocktail shaker", "504": "coffee mug", "505": "coffeepot", "506": "coil, spiral, volute, whorl, helix", "507": "combination lock", "508": "computer keyboard, keypad", "509": "confectionery, confectionary, candy store", "510": "container ship, containership, container vessel", "511": "convertible", "512": "corkscrew, bottle screw", "513": "cornet, horn, trumpet, trump", "514": "cowboy boot", "515": "cowboy hat, ten-gallon hat", "516": "cradle", "517": "crane", "518": "crash helmet", "519": "crate", "520": "crib, cot", "521": "Crock Pot", "522": "croquet ball", "523": "crutch", "524": "cuirass", "525": "dam, dike, dyke", "526": "desk", "527": "desktop computer", "528": "dial telephone, dial phone", "529": "diaper, nappy, napkin", "530": "digital clock", "531": "digital watch", "532": "dining table, board", "533": "dishrag, dishcloth", "534": "dishwasher, dish washer, dishwashing machine", "535": "disk brake, disc brake", "536": "dock, dockage, docking facility", "537": "dogsled, dog sled, dog sleigh", "538": "dome", "539": "doormat, welcome mat", "540": "drilling platform, offshore rig", "541": "drum, membranophone, tympan", "542": "drumstick", "543": "dumbbell", "544": "Dutch oven", "545": "electric fan, blower", "546": "electric guitar", "547": "electric locomotive", "548": "entertainment center", "549": "envelope", "550": "espresso maker", "551": "face powder", "552": "feather boa, boa", "553": "file, file cabinet, filing cabinet", "554": "fireboat", "555": "fire engine, fire truck", "556": "fire screen, fireguard", "557": "flagpole, flagstaff", "558": "flute, transverse flute", "559": "folding chair", "560": "football helmet", "561": "forklift", "562": "fountain", "563": "fountain pen", "564": "four-poster", "565": "freight car", "566": "French horn, horn", "567": "frying pan, frypan, skillet", "568": "fur coat", "569": "garbage truck, dustcart", "570": "gasmask, respirator, gas helmet", "571": "gas pump, gasoline pump, petrol pump, island dispenser", "572": "goblet", "573": "go-kart", "574": "golf ball", "575": "golfcart, golf cart", "576": "gondola", "577": "gong, tam-tam", "578": "gown", "579": "grand piano, grand", "580": "greenhouse, nursery, glasshouse", "581": "grille, radiator grille", "582": "grocery store, grocery, food market, market", "583": "guillotine", "584": "hair slide", "585": "hair spray", "586": "half track", "587": "hammer", "588": "hamper", "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier", "590": "hand-held computer, hand-held microcomputer", "591": "handkerchief, hankie, hanky, hankey", "592": "hard disc, hard disk, fixed disk", "593": "harmonica, mouth organ, harp, mouth harp", "594": "harp", "595": "harvester, reaper", "596": "hatchet", "597": "holster", "598": "home theater, home theatre", "599": "honeycomb", "600": "hook, claw", "601": "hoopskirt, crinoline", "602": "horizontal bar, high bar", "603": "horse cart, horse-cart", "604": "hourglass", "605": "iPod", "606": "iron, smoothing iron", "607": "jack-o'-lantern", "608": "jean, blue jean, denim", "609": "jeep, landrover", "610": "jersey, T-shirt, tee shirt", "611": "jigsaw puzzle", "612": "jinrikisha, ricksha, rickshaw", "613": "joystick", "614": "kimono", "615": "knee pad", "616": "knot", "617": "lab coat, laboratory coat", "618": "ladle", "619": "lampshade, lamp shade", "620": "laptop, laptop computer", "621": "lawn mower, mower", "622": "lens 
cap, lens cover", "623": "letter opener, paper knife, paperknife", "624": "library", "625": "lifeboat", "626": "lighter, light, igniter, ignitor", "627": "limousine, limo", "628": "liner, ocean liner", "629": "lipstick, lip rouge", "630": "Loafer", "631": "lotion", "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", "633": "loupe, jeweler's loupe", "634": "lumbermill, sawmill", "635": "magnetic compass", "636": "mailbag, postbag", "637": "mailbox, letter box", "638": "maillot", "639": "maillot, tank suit", "640": "manhole cover", "641": "maraca", "642": "marimba, xylophone", "643": "mask", "644": "matchstick", "645": "maypole", "646": "maze, labyrinth", "647": "measuring cup", "648": "medicine chest, medicine cabinet", "649": "megalith, megalithic structure", "650": "microphone, mike", "651": "microwave, microwave oven", "652": "military uniform", "653": "milk can", "654": "minibus", "655": "miniskirt, mini", "656": "minivan", "657": "missile", "658": "mitten", "659": "mixing bowl", "660": "mobile home, manufactured home", "661": "Model T", "662": "modem", "663": "monastery", "664": "monitor", "665": "moped", "666": "mortar", "667": "mortarboard", "668": "mosque", "669": "mosquito net", "670": "motor scooter, scooter", "671": "mountain bike, all-terrain bike, off-roader", "672": "mountain tent", "673": "mouse, computer mouse", "674": "mousetrap", "675": "moving van", "676": "muzzle", "677": "nail", "678": "neck brace", "679": "necklace", "680": "nipple", "681": "notebook, notebook computer", "682": "obelisk", "683": "oboe, hautboy, hautbois", "684": "ocarina, sweet potato", "685": "odometer, hodometer, mileometer, milometer", "686": "oil filter", "687": "organ, pipe organ", "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO", "689": "overskirt", "690": "oxcart", "691": "oxygen mask", "692": "packet", "693": "paddle, boat paddle", "694": "paddlewheel, paddle wheel", "695": "padlock", "696": "paintbrush", "697": "pajama, pyjama, pj's, jammies", "698": "palace", "699": "panpipe, pandean pipe, syrinx", "700": "paper towel", "701": "parachute, chute", "702": "parallel bars, bars", "703": "park bench", "704": "parking meter", "705": "passenger car, coach, carriage", "706": "patio, terrace", "707": "pay-phone, pay-station", "708": "pedestal, plinth, footstall", "709": "pencil box, pencil case", "710": "pencil sharpener", "711": "perfume, essence", "712": "Petri dish", "713": "photocopier", "714": "pick, plectrum, plectron", "715": "pickelhaube", "716": "picket fence, paling", "717": "pickup, pickup truck", "718": "pier", "719": "piggy bank, penny bank", "720": "pill bottle", "721": "pillow", "722": "ping-pong ball", "723": "pinwheel", "724": "pirate, pirate ship", "725": "pitcher, ewer", "726": "plane, carpenter's plane, woodworking plane", "727": "planetarium", "728": "plastic bag", "729": "plate rack", "730": "plow, plough", "731": "plunger, plumber's helper", "732": "Polaroid camera, Polaroid Land camera", "733": "pole", "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", "735": "poncho", "736": "pool table, billiard table, snooker table", "737": "pop bottle, soda bottle", "738": "pot, flowerpot", "739": "potter's wheel", "740": "power drill", "741": "prayer rug, prayer mat", "742": "printer", "743": "prison, prison house", "744": "projectile, missile", "745": "projector", "746": "puck, hockey puck", "747": "punching bag, punch bag, punching ball, punchball", "748": "purse", "749": "quill, quill pen", "750": "quilt, comforter, 
comfort, puff", "751": "racer, race car, racing car", "752": "racket, racquet", "753": "radiator", "754": "radio, wireless", "755": "radio telescope, radio reflector", "756": "rain barrel", "757": "recreational vehicle, RV, R.V.", "758": "reel", "759": "reflex camera", "760": "refrigerator, icebox", "761": "remote control, remote", "762": "restaurant, eating house, eating place, eatery", "763": "revolver, six-gun, six-shooter", "764": "rifle", "765": "rocking chair, rocker", "766": "rotisserie", "767": "rubber eraser, rubber, pencil eraser", "768": "rugby ball", "769": "rule, ruler", "770": "running shoe", "771": "safe", "772": "safety pin", "773": "saltshaker, salt shaker", "774": "sandal", "775": "sarong", "776": "sax, saxophone", "777": "scabbard", "778": "scale, weighing machine", "779": "school bus", "780": "schooner", "781": "scoreboard", "782": "screen, CRT screen", "783": "screw", "784": "screwdriver", "785": "seat belt, seatbelt", "786": "sewing machine", "787": "shield, buckler", "788": "shoe shop, shoe-shop, shoe store", "789": "shoji", "790": "shopping basket", "791": "shopping cart", "792": "shovel", "793": "shower cap", "794": "shower curtain", "795": "ski", "796": "ski mask", "797": "sleeping bag", "798": "slide rule, slipstick", "799": "sliding door", "800": "slot, one-armed bandit", "801": "snorkel", "802": "snowmobile", "803": "snowplow, snowplough", "804": "soap dispenser", "805": "soccer ball", "806": "sock", "807": "solar dish, solar collector, solar furnace", "808": "sombrero", "809": "soup bowl", "810": "space bar", "811": "space heater", "812": "space shuttle", "813": "spatula", "814": "speedboat", "815": "spider web, spider's web", "816": "spindle", "817": "sports car, sport car", "818": "spotlight, spot", "819": "stage", "820": "steam locomotive", "821": "steel arch bridge", "822": "steel drum", "823": "stethoscope", "824": "stole", "825": "stone wall", "826": "stopwatch, stop watch", "827": "stove", "828": "strainer", "829": "streetcar, tram, tramcar, trolley, trolley car", "830": "stretcher", "831": "studio couch, day bed", "832": "stupa, tope", "833": "submarine, pigboat, sub, U-boat", "834": "suit, suit of clothes", "835": "sundial", "836": "sunglass", "837": "sunglasses, dark glasses, shades", "838": "sunscreen, sunblock, sun blocker", "839": "suspension bridge", "840": "swab, swob, mop", "841": "sweatshirt", "842": "swimming trunks, bathing trunks", "843": "swing", "844": "switch, electric switch, electrical switch", "845": "syringe", "846": "table lamp", "847": "tank, army tank, armored combat vehicle, armoured combat vehicle", "848": "tape player", "849": "teapot", "850": "teddy, teddy bear", "851": "television, television system", "852": "tennis ball", "853": "thatch, thatched roof", "854": "theater curtain, theatre curtain", "855": "thimble", "856": "thresher, thrasher, threshing machine", "857": "throne", "858": "tile roof", "859": "toaster", "860": "tobacco shop, tobacconist shop, tobacconist", "861": "toilet seat", "862": "torch", "863": "totem pole", "864": "tow truck, tow car, wrecker", "865": "toyshop", "866": "tractor", "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", "868": "tray", "869": "trench coat", "870": "tricycle, trike, velocipede", "871": "trimaran", "872": "tripod", "873": "triumphal arch", "874": "trolleybus, trolley coach, trackless trolley", "875": "trombone", "876": "tub, vat", "877": "turnstile", "878": "typewriter keyboard", "879": "umbrella", "880": "unicycle, monocycle", "881": "upright, 
upright piano", "882": "vacuum, vacuum cleaner", "883": "vase", "884": "vault", "885": "velvet", "886": "vending machine", "887": "vestment", "888": "viaduct", "889": "violin, fiddle", "890": "volleyball", "891": "waffle iron", "892": "wall clock", "893": "wallet, billfold, notecase, pocketbook", "894": "wardrobe, closet, press", "895": "warplane, military plane", "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin", "897": "washer, automatic washer, washing machine", "898": "water bottle", "899": "water jug", "900": "water tower", "901": "whiskey jug", "902": "whistle", "903": "wig", "904": "window screen", "905": "window shade", "906": "Windsor tie", "907": "wine bottle", "908": "wing", "909": "wok", "910": "wooden spoon", "911": "wool, woolen, woollen", "912": "worm fence, snake fence, snake-rail fence, Virginia fence", "913": "wreck", "914": "yawl", "915": "yurt", "916": "web site, website, internet site, site", "917": "comic book", "918": "crossword puzzle, crossword", "919": "street sign", "920": "traffic light, traffic signal, stoplight", "921": "book jacket, dust cover, dust jacket, dust wrapper", "922": "menu", "923": "plate", "924": "guacamole", "925": "consomme", "926": "hot pot, hotpot", "927": "trifle", "928": "ice cream, icecream", "929": "ice lolly, lolly, lollipop, popsicle", "930": "French loaf", "931": "bagel, beigel", "932": "pretzel", "933": "cheeseburger", "934": "hotdog, hot dog, red hot", "935": "mashed potato", "936": "head cabbage", "937": "broccoli", "938": "cauliflower", "939": "zucchini, courgette", "940": "spaghetti squash", "941": "acorn squash", "942": "butternut squash", "943": "cucumber, cuke", "944": "artichoke, globe artichoke", "945": "bell pepper", "946": "cardoon", "947": "mushroom", "948": "Granny Smith", "949": "strawberry", "950": "orange", "951": "lemon", "952": "fig", "953": "pineapple, ananas", "954": "banana", "955": "jackfruit, jak, jack", "956": "custard apple", "957": "pomegranate", "958": "hay", "959": "carbonara", "960": "chocolate sauce, chocolate syrup", "961": "dough", "962": "meat loaf, meatloaf", "963": "pizza, pizza pie", "964": "potpie", "965": "burrito", "966": "red wine", "967": "espresso", "968": "cup", "969": "eggnog", "970": "alp", "971": "bubble", "972": "cliff, drop, drop-off", "973": "coral reef", "974": "geyser", "975": "lakeside, lakeshore", "976": "promontory, headland, head, foreland", "977": "sandbar, sand bar", "978": "seashore, coast, seacoast, sea-coast", "979": "valley, vale", "980": "volcano", "981": "ballplayer, baseball player", "982": "groom, bridegroom", "983": "scuba diver", "984": "rapeseed", "985": "daisy", "986": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", "987": "corn", "988": "acorn", "989": "hip, rose hip, rosehip", "990": "buckeye, horse chestnut, conker", "991": "coral fungus", "992": "agaric", "993": "gyromitra", "994": "stinkhorn, carrion fungus", "995": "earthstar", "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", "997": "bolete", "998": "ear, spike, capitulum", "999": "toilet tissue, toilet paper, bathroom tissue"}
\ No newline at end of file
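
imagenet_1000.json maps stringified class indices ("0" through "999") to human-readable labels, so a model's argmax can be turned into a name. A small sketch, assuming the file sits next to the script (the helper name index_to_label is hypothetical):

    import json

    def index_to_label(idx, labels_path='imagenet_1000.json'):
        # Keys in the JSON are strings, so convert the integer index.
        with open(labels_path) as f:
            labels = json.load(f)
        return labels[str(idx)]

    print(index_to_label(281))  # "tabby, tabby cat"
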
diff --git a/imagenet_1000.txt b/imagenet_1000.txt
new file mode 100644
index 0000000..376e180
--- /dev/null
+++ b/imagenet_1000.txt
@@ -0,0 +1,1000 @@
+0 tench, Tinca tinca
+1 goldfish, Carassius auratus
+2 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
+3 tiger shark, Galeocerdo cuvieri
+4 hammerhead, hammerhead shark
+5 electric ray, crampfish, numbfish, torpedo
+6 stingray
+7 cock
+8 hen
+9 ostrich, Struthio camelus
+10 brambling, Fringilla montifringilla
+11 goldfinch, Carduelis carduelis
+12 house finch, linnet, Carpodacus mexicanus
+13 junco, snowbird
+14 indigo bunting, indigo finch, indigo bird, Passerina cyanea
+15 robin, American robin, Turdus migratorius
+16 bulbul
+17 jay
+18 magpie
+19 chickadee
+20 water ouzel, dipper
+21 kite
+22 bald eagle, American eagle, Haliaeetus leucocephalus
+23 vulture
+24 great grey owl, great gray owl, Strix nebulosa
+25 European fire salamander, Salamandra salamandra
+26 common newt, Triturus vulgaris
+27 eft
+28 spotted salamander, Ambystoma maculatum
+29 axolotl, mud puppy, Ambystoma mexicanum
+30 bullfrog, Rana catesbeiana
+31 tree frog, tree-frog
+32 tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
+33 loggerhead, loggerhead turtle, Caretta caretta
+34 leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea
+35 mud turtle
+36 terrapin
+37 box turtle, box tortoise
+38 banded gecko
+39 common iguana, iguana, Iguana iguana
+40 American chameleon, anole, Anolis carolinensis
+41 whiptail, whiptail lizard
+42 agama
+43 frilled lizard, Chlamydosaurus kingi
+44 alligator lizard
+45 Gila monster, Heloderma suspectum
+46 green lizard, Lacerta viridis
+47 African chameleon, Chamaeleo chamaeleon
+48 Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis
+49 African crocodile, Nile crocodile, Crocodylus niloticus
+50 American alligator, Alligator mississipiensis
+51 triceratops
+52 thunder snake, worm snake, Carphophis amoenus
+53 ringneck snake, ring-necked snake, ring snake
+54 hognose snake, puff adder, sand viper
+55 green snake, grass snake
+56 king snake, kingsnake
+57 garter snake, grass snake
+58 water snake
+59 vine snake
+60 night snake, Hypsiglena torquata
+61 boa constrictor, Constrictor constrictor
+62 rock python, rock snake, Python sebae
+63 Indian cobra, Naja naja
+64 green mamba
+65 sea snake
+66 horned viper, cerastes, sand viper, horned asp, Cerastes cornutus
+67 diamondback, diamondback rattlesnake, Crotalus adamanteus
+68 sidewinder, horned rattlesnake, Crotalus cerastes
+69 trilobite
+70 harvestman, daddy longlegs, Phalangium opilio
+71 scorpion
+72 black and gold garden spider, Argiope aurantia
+73 barn spider, Araneus cavaticus
+74 garden spider, Aranea diademata
+75 black widow, Latrodectus mactans
+76 tarantula
+77 wolf spider, hunting spider
+78 tick
+79 centipede
+80 black grouse
+81 ptarmigan
+82 ruffed grouse, partridge, Bonasa umbellus
+83 prairie chicken, prairie grouse, prairie fowl
+84 peacock
+85 quail
+86 partridge
+87 African grey, African gray, Psittacus erithacus
+88 macaw
+89 sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita
+90 lorikeet
+91 coucal
+92 bee eater
+93 hornbill
+94 hummingbird
+95 jacamar
+96 toucan
+97 drake
+98 red-breasted merganser, Mergus serrator
+99 goose
+100 black swan, Cygnus atratus
+101 tusker
+102 echidna, spiny anteater, anteater
+103 platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus
+104 wallaby, brush kangaroo
+105 koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
+106 wombat
+107 jellyfish
+108 sea anemone, anemone
+109 brain coral
+110 flatworm, platyhelminth
+111 nematode, nematode worm, roundworm
+112 conch
+113 snail
+114 slug
+115 sea slug, nudibranch
+116 chiton, coat-of-mail shell, sea cradle, polyplacophore
+117 chambered nautilus, pearly nautilus, nautilus
+118 Dungeness crab, Cancer magister
+119 rock crab, Cancer irroratus
+120 fiddler crab
+121 king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica
+122 American lobster, Northern lobster, Maine lobster, Homarus americanus
+123 spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
+124 crayfish, crawfish, crawdad, crawdaddy
+125 hermit crab
+126 isopod
+127 white stork, Ciconia ciconia
+128 black stork, Ciconia nigra
+129 spoonbill
+130 flamingo
+131 little blue heron, Egretta caerulea
+132 American egret, great white heron, Egretta albus
+133 bittern
+134 crane
+135 limpkin, Aramus pictus
+136 European gallinule, Porphyrio porphyrio
+137 American coot, marsh hen, mud hen, water hen, Fulica americana
+138 bustard
+139 ruddy turnstone, Arenaria interpres
+140 red-backed sandpiper, dunlin, Erolia alpina
+141 redshank, Tringa totanus
+142 dowitcher
+143 oystercatcher, oyster catcher
+144 pelican
+145 king penguin, Aptenodytes patagonica
+146 albatross, mollymawk
+147 grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus
+148 killer whale, killer, orca, grampus, sea wolf, Orcinus orca
+149 dugong, Dugong dugon
+150 sea lion
+151 Chihuahua
+152 Japanese spaniel
+153 Maltese dog, Maltese terrier, Maltese
+154 Pekinese, Pekingese, Peke
+155 Shih-Tzu
+156 Blenheim spaniel
+157 papillon
+158 toy terrier
+159 Rhodesian ridgeback
+160 Afghan hound, Afghan
+161 basset, basset hound
+162 beagle
+163 bloodhound, sleuthhound
+164 bluetick
+165 black-and-tan coonhound
+166 Walker hound, Walker foxhound
+167 English foxhound
+168 redbone
+169 borzoi, Russian wolfhound
+170 Irish wolfhound
+171 Italian greyhound
+172 whippet
+173 Ibizan hound, Ibizan Podenco
+174 Norwegian elkhound, elkhound
+175 otterhound, otter hound
+176 Saluki, gazelle hound
+177 Scottish deerhound, deerhound
+178 Weimaraner
+179 Staffordshire bullterrier, Staffordshire bull terrier
+180 American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier
+181 Bedlington terrier
+182 Border terrier
+183 Kerry blue terrier
+184 Irish terrier
+185 Norfolk terrier
+186 Norwich terrier
+187 Yorkshire terrier
+188 wire-haired fox terrier
+189 Lakeland terrier
+190 Sealyham terrier, Sealyham
+191 Airedale, Airedale terrier
+192 cairn, cairn terrier
+193 Australian terrier
+194 Dandie Dinmont, Dandie Dinmont terrier
+195 Boston bull, Boston terrier
+196 miniature schnauzer
+197 giant schnauzer
+198 standard schnauzer
+199 Scotch terrier, Scottish terrier, Scottie
+200 Tibetan terrier, chrysanthemum dog
+201 silky terrier, Sydney silky
+202 soft-coated wheaten terrier
+203 West Highland white terrier
+204 Lhasa, Lhasa apso
+205 flat-coated retriever
+206 curly-coated retriever
+207 golden retriever
+208 Labrador retriever
+209 Chesapeake Bay retriever
+210 German short-haired pointer
+211 vizsla, Hungarian pointer
+212 English setter
+213 Irish setter, red setter
+214 Gordon setter
+215 Brittany spaniel
+216 clumber, clumber spaniel
+217 English springer, English springer spaniel
+218 Welsh springer spaniel
+219 cocker spaniel, English cocker spaniel, cocker
+220 Sussex spaniel
+221 Irish water spaniel
+222 kuvasz
+223 schipperke
+224 groenendael
+225 malinois
+226 briard
+227 kelpie
+228 komondor
+229 Old English sheepdog, bobtail
+230 Shetland sheepdog, Shetland sheep dog, Shetland
+231 collie
+232 Border collie
+233 Bouvier des Flandres, Bouviers des Flandres
+234 Rottweiler
+235 German shepherd, German shepherd dog, German police dog, alsatian
+236 Doberman, Doberman pinscher
+237 miniature pinscher
+238 Greater Swiss Mountain dog
+239 Bernese mountain dog
+240 Appenzeller
+241 EntleBucher
+242 boxer
+243 bull mastiff
+244 Tibetan mastiff
+245 French bulldog
+246 Great Dane
+247 Saint Bernard, St Bernard
+248 Eskimo dog, husky
+249 malamute, malemute, Alaskan malamute
+250 Siberian husky
+251 dalmatian, coach dog, carriage dog
+252 affenpinscher, monkey pinscher, monkey dog
+253 basenji
+254 pug, pug-dog
+255 Leonberg
+256 Newfoundland, Newfoundland dog
+257 Great Pyrenees
+258 Samoyed, Samoyede
+259 Pomeranian
+260 chow, chow chow
+261 keeshond
+262 Brabancon griffon
+263 Pembroke, Pembroke Welsh corgi
+264 Cardigan, Cardigan Welsh corgi
+265 toy poodle
+266 miniature poodle
+267 standard poodle
+268 Mexican hairless
+269 timber wolf, grey wolf, gray wolf, Canis lupus
+270 white wolf, Arctic wolf, Canis lupus tundrarum
+271 red wolf, maned wolf, Canis rufus, Canis niger
+272 coyote, prairie wolf, brush wolf, Canis latrans
+273 dingo, warrigal, warragal, Canis dingo
+274 dhole, Cuon alpinus
+275 African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus
+276 hyena, hyaena
+277 red fox, Vulpes vulpes
+278 kit fox, Vulpes macrotis
+279 Arctic fox, white fox, Alopex lagopus
+280 grey fox, gray fox, Urocyon cinereoargenteus
+281 tabby, tabby cat
+282 tiger cat
+283 Persian cat
+284 Siamese cat, Siamese
+285 Egyptian cat
+286 cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
+287 lynx, catamount
+288 leopard, Panthera pardus
+289 snow leopard, ounce, Panthera uncia
+290 jaguar, panther, Panthera onca, Felis onca
+291 lion, king of beasts, Panthera leo
+292 tiger, Panthera tigris
+293 cheetah, chetah, Acinonyx jubatus
+294 brown bear, bruin, Ursus arctos
+295 American black bear, black bear, Ursus americanus, Euarctos americanus
+296 ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus
+297 sloth bear, Melursus ursinus, Ursus ursinus
+298 mongoose
+299 meerkat, mierkat
+300 tiger beetle
+301 ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
+302 ground beetle, carabid beetle
+303 long-horned beetle, longicorn, longicorn beetle
+304 leaf beetle, chrysomelid
+305 dung beetle
+306 rhinoceros beetle
+307 weevil
+308 fly
+309 bee
+310 ant, emmet, pismire
+311 grasshopper, hopper
+312 cricket
+313 walking stick, walkingstick, stick insect
+314 cockroach, roach
+315 mantis, mantid
+316 cicada, cicala
+317 leafhopper
+318 lacewing, lacewing fly
+319 dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
+320 damselfly
+321 admiral
+322 ringlet, ringlet butterfly
+323 monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
+324 cabbage butterfly
+325 sulphur butterfly, sulfur butterfly
+326 lycaenid, lycaenid butterfly
+327 starfish, sea star
+328 sea urchin
+329 sea cucumber, holothurian
+330 wood rabbit, cottontail, cottontail rabbit
+331 hare
+332 Angora, Angora rabbit
+333 hamster
+334 porcupine, hedgehog
+335 fox squirrel, eastern fox squirrel, Sciurus niger
+336 marmot
+337 beaver
+338 guinea pig, Cavia cobaya
+339 sorrel
+340 zebra
+341 hog, pig, grunter, squealer, Sus scrofa
+342 wild boar, boar, Sus scrofa
+343 warthog
+344 hippopotamus, hippo, river horse, Hippopotamus amphibius
+345 ox
+346 water buffalo, water ox, Asiatic buffalo, Bubalus bubalis
+347 bison
+348 ram, tup
+349 bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
+350 ibex, Capra ibex
+351 hartebeest
+352 impala, Aepyceros melampus
+353 gazelle
+354 Arabian camel, dromedary, Camelus dromedarius
+355 llama
+356 weasel
+357 mink
+358 polecat, fitch, foulmart, foumart, Mustela putorius
+359 black-footed ferret, ferret, Mustela nigripes
+360 otter
+361 skunk, polecat, wood pussy
+362 badger
+363 armadillo
+364 three-toed sloth, ai, Bradypus tridactylus
+365 orangutan, orang, orangutang, Pongo pygmaeus
+366 gorilla, Gorilla gorilla
+367 chimpanzee, chimp, Pan troglodytes
+368 gibbon, Hylobates lar
+369 siamang, Hylobates syndactylus, Symphalangus syndactylus
+370 guenon, guenon monkey
+371 patas, hussar monkey, Erythrocebus patas
+372 baboon
+373 macaque
+374 langur
+375 colobus, colobus monkey
+376 proboscis monkey, Nasalis larvatus
+377 marmoset
+378 capuchin, ringtail, Cebus capucinus
+379 howler monkey, howler
+380 titi, titi monkey
+381 spider monkey, Ateles geoffroyi
+382 squirrel monkey, Saimiri sciureus
+383 Madagascar cat, ring-tailed lemur, Lemur catta
+384 indri, indris, Indri indri, Indri brevicaudatus
+385 Indian elephant, Elephas maximus
+386 African elephant, Loxodonta africana
+387 lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
+388 giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca
+389 barracouta, snoek
+390 eel
+391 coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch
+392 rock beauty, Holocanthus tricolor
+393 anemone fish
+394 sturgeon
+395 gar, garfish, garpike, billfish, Lepisosteus osseus
+396 lionfish
+397 puffer, pufferfish, blowfish, globefish
+398 abacus
+399 abaya
+400 academic gown, academic robe, judge's robe
+401 accordion, piano accordion, squeeze box
+402 acoustic guitar
+403 aircraft carrier, carrier, flattop, attack aircraft carrier
+404 airliner
+405 airship, dirigible
+406 altar
+407 ambulance
+408 amphibian, amphibious vehicle
+409 analog clock
+410 apiary, bee house
+411 apron
+412 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin
+413 assault rifle, assault gun
+414 backpack, back pack, knapsack, packsack, rucksack, haversack
+415 bakery, bakeshop, bakehouse
+416 balance beam, beam
+417 balloon
+418 ballpoint, ballpoint pen, ballpen, Biro
+419 Band Aid
+420 banjo
+421 bannister, banister, balustrade, balusters, handrail
+422 barbell
+423 barber chair
+424 barbershop
+425 barn
+426 barometer
+427 barrel, cask
+428 barrow, garden cart, lawn cart, wheelbarrow
+429 baseball
+430 basketball
+431 bassinet
+432 bassoon
+433 bathing cap, swimming cap
+434 bath towel
+435 bathtub, bathing tub, bath, tub
+436 beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
+437 beacon, lighthouse, beacon light, pharos
+438 beaker
+439 bearskin, busby, shako
+440 beer bottle
+441 beer glass
+442 bell cote, bell cot
+443 bib
+444 bicycle-built-for-two, tandem bicycle, tandem
+445 bikini, two-piece
+446 binder, ring-binder
+447 binoculars, field glasses, opera glasses
+448 birdhouse
+449 boathouse
+450 bobsled, bobsleigh, bob
+451 bolo tie, bolo, bola tie, bola
+452 bonnet, poke bonnet
+453 bookcase
+454 bookshop, bookstore, bookstall
+455 bottlecap
+456 bow
+457 bow tie, bow-tie, bowtie
+458 brass, memorial tablet, plaque
+459 brassiere, bra, bandeau
+460 breakwater, groin, groyne, mole, bulwark, seawall, jetty
+461 breastplate, aegis, egis
+462 broom
+463 bucket, pail
+464 buckle
+465 bulletproof vest
+466 bullet train, bullet
+467 butcher shop, meat market
+468 cab, hack, taxi, taxicab
+469 caldron, cauldron
+470 candle, taper, wax light
+471 cannon
+472 canoe
+473 can opener, tin opener
+474 cardigan
+475 car mirror
+476 carousel, carrousel, merry-go-round, roundabout, whirligig
+477 carpenter's kit, tool kit
+478 carton
+479 car wheel
+480 cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
+481 cassette
+482 cassette player
+483 castle
+484 catamaran
+485 CD player
+486 cello, violoncello
+487 cellular telephone, cellular phone, cellphone, cell, mobile phone
+488 chain
+489 chainlink fence
+490 chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour
+491 chain saw, chainsaw
+492 chest
+493 chiffonier, commode
+494 chime, bell, gong
+495 china cabinet, china closet
+496 Christmas stocking
+497 church, church building
+498 cinema, movie theater, movie theatre, movie house, picture palace
+499 cleaver, meat cleaver, chopper
+500 cliff dwelling
+501 cloak
+502 clog, geta, patten, sabot
+503 cocktail shaker
+504 coffee mug
+505 coffeepot
+506 coil, spiral, volute, whorl, helix
+507 combination lock
+508 computer keyboard, keypad
+509 confectionery, confectionary, candy store
+510 container ship, containership, container vessel
+511 convertible
+512 corkscrew, bottle screw
+513 cornet, horn, trumpet, trump
+514 cowboy boot
+515 cowboy hat, ten-gallon hat
+516 cradle
+517 crane
+518 crash helmet
+519 crate
+520 crib, cot
+521 Crock Pot
+522 croquet ball
+523 crutch
+524 cuirass
+525 dam, dike, dyke
+526 desk
+527 desktop computer
+528 dial telephone, dial phone
+529 diaper, nappy, napkin
+530 digital clock
+531 digital watch
+532 dining table, board
+533 dishrag, dishcloth
+534 dishwasher, dish washer, dishwashing machine
+535 disk brake, disc brake
+536 dock, dockage, docking facility
+537 dogsled, dog sled, dog sleigh
+538 dome
+539 doormat, welcome mat
+540 drilling platform, offshore rig
+541 drum, membranophone, tympan
+542 drumstick
+543 dumbbell
+544 Dutch oven
+545 electric fan, blower
+546 electric guitar
+547 electric locomotive
+548 entertainment center
+549 envelope
+550 espresso maker
+551 face powder
+552 feather boa, boa
+553 file, file cabinet, filing cabinet
+554 fireboat
+555 fire engine, fire truck
+556 fire screen, fireguard
+557 flagpole, flagstaff
+558 flute, transverse flute
+559 folding chair
+560 football helmet
+561 forklift
+562 fountain
+563 fountain pen
+564 four-poster
+565 freight car
+566 French horn, horn
+567 frying pan, frypan, skillet
+568 fur coat
+569 garbage truck, dustcart
+570 gasmask, respirator, gas helmet
+571 gas pump, gasoline pump, petrol pump, island dispenser
+572 goblet
+573 go-kart
+574 golf ball
+575 golfcart, golf cart
+576 gondola
+577 gong, tam-tam
+578 gown
+579 grand piano, grand
+580 greenhouse, nursery, glasshouse
+581 grille, radiator grille
+582 grocery store, grocery, food market, market
+583 guillotine
+584 hair slide
+585 hair spray
+586 half track
+587 hammer
+588 hamper
+589 hand blower, blow dryer, blow drier, hair dryer, hair drier
+590 hand-held computer, hand-held microcomputer
+591 handkerchief, hankie, hanky, hankey
+592 hard disc, hard disk, fixed disk
+593 harmonica, mouth organ, harp, mouth harp
+594 harp
+595 harvester, reaper
+596 hatchet
+597 holster
+598 home theater, home theatre
+599 honeycomb
+600 hook, claw
+601 hoopskirt, crinoline
+602 horizontal bar, high bar
+603 horse cart, horse-cart
+604 hourglass
+605 iPod
+606 iron, smoothing iron
+607 jack-o'-lantern
+608 jean, blue jean, denim
+609 jeep, landrover
+610 jersey, T-shirt, tee shirt
+611 jigsaw puzzle
+612 jinrikisha, ricksha, rickshaw
+613 joystick
+614 kimono
+615 knee pad
+616 knot
+617 lab coat, laboratory coat
+618 ladle
+619 lampshade, lamp shade
+620 laptop, laptop computer
+621 lawn mower, mower
+622 lens cap, lens cover
+623 letter opener, paper knife, paperknife
+624 library
+625 lifeboat
+626 lighter, light, igniter, ignitor
+627 limousine, limo
+628 liner, ocean liner
+629 lipstick, lip rouge
+630 Loafer
+631 lotion
+632 loudspeaker, speaker, speaker unit, loudspeaker system, speaker system
+633 loupe, jeweler's loupe
+634 lumbermill, sawmill
+635 magnetic compass
+636 mailbag, postbag
+637 mailbox, letter box
+638 maillot
+639 maillot, tank suit
+640 manhole cover
+641 maraca
+642 marimba, xylophone
+643 mask
+644 matchstick
+645 maypole
+646 maze, labyrinth
+647 measuring cup
+648 medicine chest, medicine cabinet
+649 megalith, megalithic structure
+650 microphone, mike
+651 microwave, microwave oven
+652 military uniform
+653 milk can
+654 minibus
+655 miniskirt, mini
+656 minivan
+657 missile
+658 mitten
+659 mixing bowl
+660 mobile home, manufactured home
+661 Model T
+662 modem
+663 monastery
+664 monitor
+665 moped
+666 mortar
+667 mortarboard
+668 mosque
+669 mosquito net
+670 motor scooter, scooter
+671 mountain bike, all-terrain bike, off-roader
+672 mountain tent
+673 mouse, computer mouse
+674 mousetrap
+675 moving van
+676 muzzle
+677 nail
+678 neck brace
+679 necklace
+680 nipple
+681 notebook, notebook computer
+682 obelisk
+683 oboe, hautboy, hautbois
+684 ocarina, sweet potato
+685 odometer, hodometer, mileometer, milometer
+686 oil filter
+687 organ, pipe organ
+688 oscilloscope, scope, cathode-ray oscilloscope, CRO
+689 overskirt
+690 oxcart
+691 oxygen mask
+692 packet
+693 paddle, boat paddle
+694 paddlewheel, paddle wheel
+695 padlock
+696 paintbrush
+697 pajama, pyjama, pj's, jammies
+698 palace
+699 panpipe, pandean pipe, syrinx
+700 paper towel
+701 parachute, chute
+702 parallel bars, bars
+703 park bench
+704 parking meter
+705 passenger car, coach, carriage
+706 patio, terrace
+707 pay-phone, pay-station
+708 pedestal, plinth, footstall
+709 pencil box, pencil case
+710 pencil sharpener
+711 perfume, essence
+712 Petri dish
+713 photocopier
+714 pick, plectrum, plectron
+715 pickelhaube
+716 picket fence, paling
+717 pickup, pickup truck
+718 pier
+719 piggy bank, penny bank
+720 pill bottle
+721 pillow
+722 ping-pong ball
+723 pinwheel
+724 pirate, pirate ship
+725 pitcher, ewer
+726 plane, carpenter's plane, woodworking plane
+727 planetarium
+728 plastic bag
+729 plate rack
+730 plow, plough
+731 plunger, plumber's helper
+732 Polaroid camera, Polaroid Land camera
+733 pole
+734 police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
+735 poncho
+736 pool table, billiard table, snooker table
+737 pop bottle, soda bottle
+738 pot, flowerpot
+739 potter's wheel
+740 power drill
+741 prayer rug, prayer mat
+742 printer
+743 prison, prison house
+744 projectile, missile
+745 projector
+746 puck, hockey puck
+747 punching bag, punch bag, punching ball, punchball
+748 purse
+749 quill, quill pen
+750 quilt, comforter, comfort, puff
+751 racer, race car, racing car
+752 racket, racquet
+753 radiator
+754 radio, wireless
+755 radio telescope, radio reflector
+756 rain barrel
+757 recreational vehicle, RV, R.V.
+758 reel
+759 reflex camera
+760 refrigerator, icebox
+761 remote control, remote
+762 restaurant, eating house, eating place, eatery
+763 revolver, six-gun, six-shooter
+764 rifle
+765 rocking chair, rocker
+766 rotisserie
+767 rubber eraser, rubber, pencil eraser
+768 rugby ball
+769 rule, ruler
+770 running shoe
+771 safe
+772 safety pin
+773 saltshaker, salt shaker
+774 sandal
+775 sarong
+776 sax, saxophone
+777 scabbard
+778 scale, weighing machine
+779 school bus
+780 schooner
+781 scoreboard
+782 screen, CRT screen
+783 screw
+784 screwdriver
+785 seat belt, seatbelt
+786 sewing machine
+787 shield, buckler
+788 shoe shop, shoe-shop, shoe store
+789 shoji
+790 shopping basket
+791 shopping cart
+792 shovel
+793 shower cap
+794 shower curtain
+795 ski
+796 ski mask
+797 sleeping bag
+798 slide rule, slipstick
+799 sliding door
+800 slot, one-armed bandit
+801 snorkel
+802 snowmobile
+803 snowplow, snowplough
+804 soap dispenser
+805 soccer ball
+806 sock
+807 solar dish, solar collector, solar furnace
+808 sombrero
+809 soup bowl
+810 space bar
+811 space heater
+812 space shuttle
+813 spatula
+814 speedboat
+815 spider web, spider's web
+816 spindle
+817 sports car, sport car
+818 spotlight, spot
+819 stage
+820 steam locomotive
+821 steel arch bridge
+822 steel drum
+823 stethoscope
+824 stole
+825 stone wall
+826 stopwatch, stop watch
+827 stove
+828 strainer
+829 streetcar, tram, tramcar, trolley, trolley car
+830 stretcher
+831 studio couch, day bed
+832 stupa, tope
+833 submarine, pigboat, sub, U-boat
+834 suit, suit of clothes
+835 sundial
+836 sunglass
+837 sunglasses, dark glasses, shades
+838 sunscreen, sunblock, sun blocker
+839 suspension bridge
+840 swab, swob, mop
+841 sweatshirt
+842 swimming trunks, bathing trunks
+843 swing
+844 switch, electric switch, electrical switch
+845 syringe
+846 table lamp
+847 tank, army tank, armored combat vehicle, armoured combat vehicle
+848 tape player
+849 teapot
+850 teddy, teddy bear
+851 television, television system
+852 tennis ball
+853 thatch, thatched roof
+854 theater curtain, theatre curtain
+855 thimble
+856 thresher, thrasher, threshing machine
+857 throne
+858 tile roof
+859 toaster
+860 tobacco shop, tobacconist shop, tobacconist
+861 toilet seat
+862 torch
+863 totem pole
+864 tow truck, tow car, wrecker
+865 toyshop
+866 tractor
+867 trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi
+868 tray
+869 trench coat
+870 tricycle, trike, velocipede
+871 trimaran
+872 tripod
+873 triumphal arch
+874 trolleybus, trolley coach, trackless trolley
+875 trombone
+876 tub, vat
+877 turnstile
+878 typewriter keyboard
+879 umbrella
+880 unicycle, monocycle
+881 upright, upright piano
+882 vacuum, vacuum cleaner
+883 vase
+884 vault
+885 velvet
+886 vending machine
+887 vestment
+888 viaduct
+889 violin, fiddle
+890 volleyball
+891 waffle iron
+892 wall clock
+893 wallet, billfold, notecase, pocketbook
+894 wardrobe, closet, press
+895 warplane, military plane
+896 washbasin, handbasin, washbowl, lavabo, wash-hand basin
+897 washer, automatic washer, washing machine
+898 water bottle
+899 water jug
+900 water tower
+901 whiskey jug
+902 whistle
+903 wig
+904 window screen
+905 window shade
+906 Windsor tie
+907 wine bottle
+908 wing
+909 wok
+910 wooden spoon
+911 wool, woolen, woollen
+912 worm fence, snake fence, snake-rail fence, Virginia fence
+913 wreck
+914 yawl
+915 yurt
+916 web site, website, internet site, site
+917 comic book
+918 crossword puzzle, crossword
+919 street sign
+920 traffic light, traffic signal, stoplight
+921 book jacket, dust cover, dust jacket, dust wrapper
+922 menu
+923 plate
+924 guacamole
+925 consomme
+926 hot pot, hotpot
+927 trifle
+928 ice cream, icecream
+929 ice lolly, lolly, lollipop, popsicle
+930 French loaf
+931 bagel, beigel
+932 pretzel
+933 cheeseburger
+934 hotdog, hot dog, red hot
+935 mashed potato
+936 head cabbage
+937 broccoli
+938 cauliflower
+939 zucchini, courgette
+940 spaghetti squash
+941 acorn squash
+942 butternut squash
+943 cucumber, cuke
+944 artichoke, globe artichoke
+945 bell pepper
+946 cardoon
+947 mushroom
+948 Granny Smith
+949 strawberry
+950 orange
+951 lemon
+952 fig
+953 pineapple, ananas
+954 banana
+955 jackfruit, jak, jack
+956 custard apple
+957 pomegranate
+958 hay
+959 carbonara
+960 chocolate sauce, chocolate syrup
+961 dough
+962 meat loaf, meatloaf
+963 pizza, pizza pie
+964 potpie
+965 burrito
+966 red wine
+967 espresso
+968 cup
+969 eggnog
+970 alp
+971 bubble
+972 cliff, drop, drop-off
+973 coral reef
+974 geyser
+975 lakeside, lakeshore
+976 promontory, headland, head, foreland
+977 sandbar, sand bar
+978 seashore, coast, seacoast, sea-coast
+979 valley, vale
+980 volcano
+981 ballplayer, baseball player
+982 groom, bridegroom
+983 scuba diver
+984 rapeseed
+985 daisy
+986 yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum
+987 corn
+988 acorn
+989 hip, rose hip, rosehip
+990 buckeye, horse chestnut, conker
+991 coral fungus
+992 agaric
+993 gyromitra
+994 stinkhorn, carrion fungus
+995 earthstar
+996 hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa
+997 bolete
+998 ear, spike, capitulum
+999 toilet tissue, toilet paper, bathroom tissue
diff --git a/pytorch_grad_cam/__init__.py b/pytorch_grad_cam/__init__.py
new file mode 100644
index 0000000..4d6e8f3
--- /dev/null
+++ b/pytorch_grad_cam/__init__.py
@@ -0,0 +1,20 @@
+from pytorch_grad_cam.grad_cam import GradCAM
+from pytorch_grad_cam.hirescam import HiResCAM
+from pytorch_grad_cam.grad_cam_elementwise import GradCAMElementWise
+from pytorch_grad_cam.ablation_layer import AblationLayer, AblationLayerVit, AblationLayerFasterRCNN
+from pytorch_grad_cam.ablation_cam import AblationCAM
+from pytorch_grad_cam.xgrad_cam import XGradCAM
+from pytorch_grad_cam.grad_cam_plusplus import GradCAMPlusPlus
+from pytorch_grad_cam.score_cam import ScoreCAM
+from pytorch_grad_cam.layer_cam import LayerCAM
+from pytorch_grad_cam.eigen_cam import EigenCAM
+from pytorch_grad_cam.eigen_grad_cam import EigenGradCAM
+from pytorch_grad_cam.random_cam import RandomCAM
+from pytorch_grad_cam.fullgrad_cam import FullGrad
+from pytorch_grad_cam.guided_backprop import GuidedBackpropReLUModel
+from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
+from pytorch_grad_cam.feature_factorization.deep_feature_factorization import DeepFeatureFactorization, run_dff_on_image
+import pytorch_grad_cam.utils.model_targets
+import pytorch_grad_cam.utils.reshape_transforms
+import pytorch_grad_cam.metrics.cam_mult_image
+import pytorch_grad_cam.metrics.road
diff --git a/pytorch_grad_cam/__pycache__/__init__.cpython-37.pyc b/pytorch_grad_cam/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..f550765
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/__init__.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/__init__.cpython-38.pyc b/pytorch_grad_cam/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..556d829
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/__init__.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/ablation_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/ablation_cam.cpython-37.pyc
new file mode 100644
index 0000000..a301e45
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/ablation_cam.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/ablation_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/ablation_cam.cpython-38.pyc
new file mode 100644
index 0000000..caed0c2
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/ablation_cam.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/ablation_layer.cpython-37.pyc b/pytorch_grad_cam/__pycache__/ablation_layer.cpython-37.pyc
new file mode 100644
index 0000000..04f5a5c
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/ablation_layer.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/ablation_layer.cpython-38.pyc b/pytorch_grad_cam/__pycache__/ablation_layer.cpython-38.pyc
new file mode 100644
index 0000000..4ffd254
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/ablation_layer.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/activations_and_gradients.cpython-37.pyc b/pytorch_grad_cam/__pycache__/activations_and_gradients.cpython-37.pyc
new file mode 100644
index 0000000..ea35268
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/activations_and_gradients.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/activations_and_gradients.cpython-38.pyc b/pytorch_grad_cam/__pycache__/activations_and_gradients.cpython-38.pyc
new file mode 100644
index 0000000..d74710b
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/activations_and_gradients.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/base_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/base_cam.cpython-37.pyc
new file mode 100644
index 0000000..b7e8d0c
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/base_cam.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/base_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/base_cam.cpython-38.pyc
new file mode 100644
index 0000000..6c60dff
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/base_cam.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/eigen_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/eigen_cam.cpython-37.pyc
new file mode 100644
index 0000000..113285b
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/eigen_cam.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/eigen_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/eigen_cam.cpython-38.pyc
new file mode 100644
index 0000000..8a9ab14
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/eigen_cam.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/eigen_grad_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/eigen_grad_cam.cpython-37.pyc
new file mode 100644
index 0000000..f325b77
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/eigen_grad_cam.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/eigen_grad_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/eigen_grad_cam.cpython-38.pyc
new file mode 100644
index 0000000..bae5a3d
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/eigen_grad_cam.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/fullgrad_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/fullgrad_cam.cpython-37.pyc
new file mode 100644
index 0000000..5105c8a
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/fullgrad_cam.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/fullgrad_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/fullgrad_cam.cpython-38.pyc
new file mode 100644
index 0000000..2cb48a8
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/fullgrad_cam.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/grad_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/grad_cam.cpython-37.pyc
new file mode 100644
index 0000000..e64d058
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/grad_cam.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/grad_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/grad_cam.cpython-38.pyc
new file mode 100644
index 0000000..b3c13fe
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/grad_cam.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/grad_cam_elementwise.cpython-37.pyc b/pytorch_grad_cam/__pycache__/grad_cam_elementwise.cpython-37.pyc
new file mode 100644
index 0000000..2482d60
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/grad_cam_elementwise.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/grad_cam_elementwise.cpython-38.pyc b/pytorch_grad_cam/__pycache__/grad_cam_elementwise.cpython-38.pyc
new file mode 100644
index 0000000..3daee8a
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/grad_cam_elementwise.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/grad_cam_plusplus.cpython-37.pyc b/pytorch_grad_cam/__pycache__/grad_cam_plusplus.cpython-37.pyc
new file mode 100644
index 0000000..e8b7d1e
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/grad_cam_plusplus.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/grad_cam_plusplus.cpython-38.pyc b/pytorch_grad_cam/__pycache__/grad_cam_plusplus.cpython-38.pyc
new file mode 100644
index 0000000..988e95f
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/grad_cam_plusplus.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/guided_backprop.cpython-37.pyc b/pytorch_grad_cam/__pycache__/guided_backprop.cpython-37.pyc
new file mode 100644
index 0000000..03571a1
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/guided_backprop.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/guided_backprop.cpython-38.pyc b/pytorch_grad_cam/__pycache__/guided_backprop.cpython-38.pyc
new file mode 100644
index 0000000..b9ca333
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/guided_backprop.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/hirescam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/hirescam.cpython-37.pyc
new file mode 100644
index 0000000..4e95b72
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/hirescam.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/hirescam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/hirescam.cpython-38.pyc
new file mode 100644
index 0000000..680e12e
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/hirescam.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/layer_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/layer_cam.cpython-37.pyc
new file mode 100644
index 0000000..e47bddf
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/layer_cam.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/layer_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/layer_cam.cpython-38.pyc
new file mode 100644
index 0000000..30022f9
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/layer_cam.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/random_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/random_cam.cpython-37.pyc
new file mode 100644
index 0000000..70ed073
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/random_cam.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/random_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/random_cam.cpython-38.pyc
new file mode 100644
index 0000000..dfcaa56
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/random_cam.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/score_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/score_cam.cpython-37.pyc
new file mode 100644
index 0000000..0c6e7f2
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/score_cam.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/score_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/score_cam.cpython-38.pyc
new file mode 100644
index 0000000..9b4f8a1
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/score_cam.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/xgrad_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/xgrad_cam.cpython-37.pyc
new file mode 100644
index 0000000..1f3946d
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/xgrad_cam.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/__pycache__/xgrad_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/xgrad_cam.cpython-38.pyc
new file mode 100644
index 0000000..ed2585a
Binary files /dev/null and b/pytorch_grad_cam/__pycache__/xgrad_cam.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/ablation_cam.py b/pytorch_grad_cam/ablation_cam.py
new file mode 100644
index 0000000..77e65fc
--- /dev/null
+++ b/pytorch_grad_cam/ablation_cam.py
@@ -0,0 +1,148 @@
+import numpy as np
+import torch
+import tqdm
+from typing import Callable, List
+from pytorch_grad_cam.base_cam import BaseCAM
+from pytorch_grad_cam.utils.find_layers import replace_layer_recursive
+from pytorch_grad_cam.ablation_layer import AblationLayer
+
+
+""" Implementation of AblationCAM
+https://openaccess.thecvf.com/content_WACV_2020/papers/Desai_Ablation-CAM_Visual_Explanations_for_Deep_Convolutional_Network_via_Gradient-free_Localization_WACV_2020_paper.pdf
+
+Ablate individual activations, and then measure the drop in the target score.
+
+In the current implementation, the target layer activations are cached, so they won't be re-computed.
+However, layers before it, if any, will not be cached.
+This means that if the target layer is a large block, for example model.features (in VGG), there will
+be a large saving in run time.
+
+Since we have to go over many channels and ablate them, and every channel ablation requires a forward pass,
+it would be nice if we could skip the channels that won't contribute anyway, making it much faster.
+The parameter ratio_channels_to_ablate controls how many channels should be ablated, using an experimental method
+(to be improved). The default value of 1.0 means that all channels will be ablated.
+"""
+
+
+class AblationCAM(BaseCAM):
+ def __init__(self,
+ model: torch.nn.Module,
+ target_layers: List[torch.nn.Module],
+ use_cuda: bool = False,
+ reshape_transform: Callable = None,
+ ablation_layer: torch.nn.Module = AblationLayer(),
+ batch_size: int = 32,
+ ratio_channels_to_ablate: float = 1.0) -> None:
+
+ super(AblationCAM, self).__init__(model,
+ target_layers,
+ use_cuda,
+ reshape_transform,
+ uses_gradients=False)
+ self.batch_size = batch_size
+ self.ablation_layer = ablation_layer
+ self.ratio_channels_to_ablate = ratio_channels_to_ablate
+
+ def save_activation(self, module, input, output) -> None:
+ """ Helper function to save the raw activations from the target layer """
+ self.activations = output
+
+ def assemble_ablation_scores(self,
+ new_scores: list,
+ original_score: float,
+ ablated_channels: np.ndarray,
+ number_of_channels: int) -> np.ndarray:
+ """ Take the value from the channels that were ablated,
+ and just set the original score for the channels that were skipped """
+
+ index = 0
+ result = []
+ sorted_indices = np.argsort(ablated_channels)
+ ablated_channels = ablated_channels[sorted_indices]
+ new_scores = np.float32(new_scores)[sorted_indices]
+
+ for i in range(number_of_channels):
+ if index < len(ablated_channels) and ablated_channels[index] == i:
+ weight = new_scores[index]
+ index = index + 1
+ else:
+ weight = original_score
+ result.append(weight)
+
+        return np.float32(result)
+
+ def get_cam_weights(self,
+ input_tensor: torch.Tensor,
+ target_layer: torch.nn.Module,
+ targets: List[Callable],
+ activations: torch.Tensor,
+ grads: torch.Tensor) -> np.ndarray:
+
+ # Do a forward pass, compute the target scores, and cache the
+ # activations
+ handle = target_layer.register_forward_hook(self.save_activation)
+ with torch.no_grad():
+ outputs = self.model(input_tensor)
+ handle.remove()
+ original_scores = np.float32(
+ [target(output).cpu().item() for target, output in zip(targets, outputs)])
+
+ # Replace the layer with the ablation layer.
+ # When we finish, we will replace it back, so the original model is
+ # unchanged.
+ ablation_layer = self.ablation_layer
+ replace_layer_recursive(self.model, target_layer, ablation_layer)
+
+ number_of_channels = activations.shape[1]
+ weights = []
+ # This is a "gradient free" method, so we don't need gradients here.
+ with torch.no_grad():
+ # Loop over each of the batch images and ablate activations for it.
+ for batch_index, (target, tensor) in enumerate(
+ zip(targets, input_tensor)):
+ new_scores = []
+ batch_tensor = tensor.repeat(self.batch_size, 1, 1, 1)
+
+ # Check which channels should be ablated. Normally this will be all channels,
+                # but we can also try to speed this up by using a low
+ # ratio_channels_to_ablate.
+ channels_to_ablate = ablation_layer.activations_to_be_ablated(
+ activations[batch_index, :], self.ratio_channels_to_ablate)
+ number_channels_to_ablate = len(channels_to_ablate)
+
+                for i in tqdm.tqdm(range(0, number_channels_to_ablate, self.batch_size)):
+ if i + self.batch_size > number_channels_to_ablate:
+                        batch_tensor = batch_tensor[:(number_channels_to_ablate - i)]
+
+ # Change the state of the ablation layer so it ablates the next channels.
+ # TBD: Move this into the ablation layer forward pass.
+ ablation_layer.set_next_batch(
+ input_batch_index=batch_index,
+ activations=self.activations,
+ num_channels_to_ablate=batch_tensor.size(0))
+ score = [target(o).cpu().item()
+ for o in self.model(batch_tensor)]
+ new_scores.extend(score)
+                    ablation_layer.indices = ablation_layer.indices[batch_tensor.size(0):]
+
+ new_scores = self.assemble_ablation_scores(
+ new_scores,
+ original_scores[batch_index],
+ channels_to_ablate,
+ number_of_channels)
+ weights.extend(new_scores)
+
+ weights = np.float32(weights)
+ weights = weights.reshape(activations.shape[:2])
+ original_scores = original_scores[:, None]
+ weights = (original_scores - weights) / original_scores
+
+ # Replace the model back to the original state
+ replace_layer_recursive(self.model, ablation_layer, target_layer)
+ return weights
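
For reference, a minimal sketch of how this class might be driven end to end. The model, layer choice, and target class are illustrative assumptions, not fixed by this file; the random input tensor stands in for a preprocessed image:

    import torch
    from torchvision import models
    from pytorch_grad_cam import AblationCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    model = models.resnet50(pretrained=True)
    target_layers = [model.layer4[-1]]          # last bottleneck block (an assumption)
    input_tensor = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image

    cam = AblationCAM(model, target_layers,
                      batch_size=32,
                      ratio_channels_to_ablate=1.0)
    # 954 is "banana" in the label list above; any class index works.
    grayscale_cam = cam(input_tensor, targets=[ClassifierOutputTarget(954)])
    print(grayscale_cam.shape)  # (1, 224, 224)

Lowering ratio_channels_to_ablate trades fidelity for speed, since only the channels the experimental heuristic scores as interesting get ablated.
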
diff --git a/pytorch_grad_cam/ablation_cam_multilayer.py b/pytorch_grad_cam/ablation_cam_multilayer.py
new file mode 100644
index 0000000..9b9dc80
--- /dev/null
+++ b/pytorch_grad_cam/ablation_cam_multilayer.py
@@ -0,0 +1,136 @@
+import cv2
+import numpy as np
+import torch
+import tqdm
+from pytorch_grad_cam.base_cam import BaseCAM
+
+
+class AblationLayer(torch.nn.Module):
+ def __init__(self, layer, reshape_transform, indices):
+ super(AblationLayer, self).__init__()
+
+ self.layer = layer
+ self.reshape_transform = reshape_transform
+ # The channels to zero out:
+ self.indices = indices
+
+ def forward(self, x):
+        return self.__call__(x)
+
+ def __call__(self, x):
+ output = self.layer(x)
+
+        # Hack to work with ViT:
+        # the activation channels are last, not first as in CNNs.
+        # Probably should remove it?
+ if self.reshape_transform is not None:
+ output = output.transpose(1, 2)
+
+ for i in range(output.size(0)):
+
+            # Commonly the minimum activation will be 0,
+            # and then it makes sense to zero it out.
+            # However, depending on the architecture,
+            # if the values can be negative, we use a very negative value
+            # to perform the ablation, deviating from the paper.
+            if torch.min(output) == 0:
+                output[i, self.indices[i], :] = 0
+            else:
+                ABLATION_VALUE = 1e5
+                output[i, self.indices[i], :] = torch.min(output) - ABLATION_VALUE
+
+ if self.reshape_transform is not None:
+ output = output.transpose(2, 1)
+
+ return output
+
+
+def replace_layer_recursive(model, old_layer, new_layer):
+ for name, layer in model._modules.items():
+ if layer == old_layer:
+ model._modules[name] = new_layer
+ return True
+ elif replace_layer_recursive(layer, old_layer, new_layer):
+ return True
+ return False
+
+
+class AblationCAM(BaseCAM):
+ def __init__(self, model, target_layers, use_cuda=False,
+ reshape_transform=None):
+ super(AblationCAM, self).__init__(model, target_layers, use_cuda,
+ reshape_transform)
+
+ if len(target_layers) > 1:
+ print(
+ "Warning. You are usign Ablation CAM with more than 1 layers. "
+ "This is supported only if all layers have the same output shape")
+
+ def set_ablation_layers(self):
+ self.ablation_layers = []
+ for target_layer in self.target_layers:
+ ablation_layer = AblationLayer(target_layer,
+ self.reshape_transform, indices=[])
+ self.ablation_layers.append(ablation_layer)
+ replace_layer_recursive(self.model, target_layer, ablation_layer)
+
+ def unset_ablation_layers(self):
+ # replace the model back to the original state
+ for ablation_layer, target_layer in zip(
+ self.ablation_layers, self.target_layers):
+ replace_layer_recursive(self.model, ablation_layer, target_layer)
+
+ def set_ablation_layer_batch_indices(self, indices):
+ for ablation_layer in self.ablation_layers:
+ ablation_layer.indices = indices
+
+ def trim_ablation_layer_batch_indices(self, keep):
+ for ablation_layer in self.ablation_layers:
+ ablation_layer.indices = ablation_layer.indices[:keep]
+
+ def get_cam_weights(self,
+ input_tensor,
+ target_category,
+ activations,
+ grads):
+ with torch.no_grad():
+ outputs = self.model(input_tensor).cpu().numpy()
+ original_scores = []
+ for i in range(input_tensor.size(0)):
+ original_scores.append(outputs[i, target_category[i]])
+ original_scores = np.float32(original_scores)
+
+ self.set_ablation_layers()
+
+ if hasattr(self, "batch_size"):
+ BATCH_SIZE = self.batch_size
+ else:
+ BATCH_SIZE = 32
+
+ number_of_channels = activations.shape[1]
+ weights = []
+
+ with torch.no_grad():
+ # Iterate over the input batch
+ for tensor, category in zip(input_tensor, target_category):
+ batch_tensor = tensor.repeat(BATCH_SIZE, 1, 1, 1)
+ for i in tqdm.tqdm(range(0, number_of_channels, BATCH_SIZE)):
+ self.set_ablation_layer_batch_indices(
+ list(range(i, i + BATCH_SIZE)))
+
+ if i + BATCH_SIZE > number_of_channels:
+ keep = number_of_channels - i
+ batch_tensor = batch_tensor[:keep]
+                        self.trim_ablation_layer_batch_indices(keep)
+ score = self.model(batch_tensor)[:, category].cpu().numpy()
+ weights.extend(score)
+
+ weights = np.float32(weights)
+ weights = weights.reshape(activations.shape[:2])
+ original_scores = original_scores[:, None]
+ weights = (original_scores - weights) / original_scores
+
+ # replace the model back to the original state
+ self.unset_ablation_layers()
+ return weights
diff --git a/pytorch_grad_cam/ablation_layer.py b/pytorch_grad_cam/ablation_layer.py
new file mode 100644
index 0000000..b404f3b
--- /dev/null
+++ b/pytorch_grad_cam/ablation_layer.py
@@ -0,0 +1,155 @@
+import torch
+from collections import OrderedDict
+import numpy as np
+from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
+
+
+class AblationLayer(torch.nn.Module):
+ def __init__(self):
+ super(AblationLayer, self).__init__()
+
+ def objectiveness_mask_from_svd(self, activations, threshold=0.01):
+ """ Experimental method to get a binary mask to compare if the activation is worth ablating.
+ The idea is to apply the EigenCAM method by doing PCA on the activations.
+ Then we create a binary mask by comparing to a low threshold.
+ Areas that are masked out, are probably not interesting anyway.
+ """
+
+ projection = get_2d_projection(activations[None, :])[0, :]
+ projection = np.abs(projection)
+ projection = projection - projection.min()
+ projection = projection / projection.max()
+ projection = projection > threshold
+ return projection
+
+ def activations_to_be_ablated(
+ self,
+ activations,
+ ratio_channels_to_ablate=1.0):
+ """ Experimental method to get a binary mask to compare if the activation is worth ablating.
+ Create a binary CAM mask with objectiveness_mask_from_svd.
+ Score each Activation channel, by seeing how much of its values are inside the mask.
+ Then keep the top channels.
+
+ """
+ if ratio_channels_to_ablate == 1.0:
+ self.indices = np.int32(range(activations.shape[0]))
+ return self.indices
+
+ projection = self.objectiveness_mask_from_svd(activations)
+
+ scores = []
+ for channel in activations:
+ normalized = np.abs(channel)
+ normalized = normalized - normalized.min()
+ normalized = normalized / np.max(normalized)
+ score = (projection * normalized).sum() / normalized.sum()
+ scores.append(score)
+ scores = np.float32(scores)
+
+ indices = list(np.argsort(scores))
+        high_score_indices = indices[::-1][:int(len(indices) * ratio_channels_to_ablate)]
+        low_score_indices = indices[:int(len(indices) * ratio_channels_to_ablate)]
+ self.indices = np.int32(high_score_indices + low_score_indices)
+ return self.indices
+
+ def set_next_batch(
+ self,
+ input_batch_index,
+ activations,
+ num_channels_to_ablate):
+ """ This creates the next batch of activations from the layer.
+            Just take the corresponding batch member from activations, and repeat it num_channels_to_ablate times.
+        """
+        self.activations = activations[input_batch_index, :, :, :].clone().unsqueeze(0).repeat(num_channels_to_ablate, 1, 1, 1)
+
+ def __call__(self, x):
+ output = self.activations
+ for i in range(output.size(0)):
+            # Commonly the minimum activation will be 0,
+            # and then it makes sense to zero it out.
+            # However, depending on the architecture,
+            # if the values can be negative, we use a very negative value
+            # to perform the ablation, deviating from the paper.
+            if torch.min(output) == 0:
+                output[i, self.indices[i], :] = 0
+            else:
+                ABLATION_VALUE = 1e7
+                output[i, self.indices[i], :] = torch.min(output) - ABLATION_VALUE
+
+ return output
+
+
+class AblationLayerVit(AblationLayer):
+ def __init__(self):
+ super(AblationLayerVit, self).__init__()
+
+ def __call__(self, x):
+ output = self.activations
+ output = output.transpose(1, len(output.shape) - 1)
+ for i in range(output.size(0)):
+
+            # Commonly the minimum activation will be 0,
+            # and then it makes sense to zero it out.
+            # However, depending on the architecture,
+            # if the values can be negative, we use a very negative value
+            # to perform the ablation, deviating from the paper.
+            if torch.min(output) == 0:
+                output[i, self.indices[i], :] = 0
+            else:
+                ABLATION_VALUE = 1e7
+                output[i, self.indices[i], :] = torch.min(output) - ABLATION_VALUE
+
+ output = output.transpose(len(output.shape) - 1, 1)
+
+ return output
+
+ def set_next_batch(
+ self,
+ input_batch_index,
+ activations,
+ num_channels_to_ablate):
+ """ This creates the next batch of activations from the layer.
+            Just take the corresponding batch member from activations, and repeat it num_channels_to_ablate times.
+        """
+        repeat_params = [num_channels_to_ablate] + \
+            len(activations.shape[:-1]) * [1]
+        self.activations = activations[input_batch_index, :, :].clone().unsqueeze(0).repeat(*repeat_params)
+
+
+class AblationLayerFasterRCNN(AblationLayer):
+ def __init__(self):
+ super(AblationLayerFasterRCNN, self).__init__()
+
+ def set_next_batch(
+ self,
+ input_batch_index,
+ activations,
+ num_channels_to_ablate):
+ """ Extract the next batch member from activations,
+ and repeat it num_channels_to_ablate times.
+ """
+ self.activations = OrderedDict()
+ for key, value in activations.items():
+            fpn_activation = value[input_batch_index, :, :, :].clone().unsqueeze(0)
+ self.activations[key] = fpn_activation.repeat(
+ num_channels_to_ablate, 1, 1, 1)
+
+ def __call__(self, x):
+ result = self.activations
+ layers = {0: '0', 1: '1', 2: '2', 3: '3', 4: 'pool'}
+ num_channels_to_ablate = result['pool'].size(0)
+ for i in range(num_channels_to_ablate):
+ pyramid_layer = int(self.indices[i] / 256)
+ index_in_pyramid_layer = int(self.indices[i] % 256)
+            result[layers[pyramid_layer]][i, index_in_pyramid_layer, :, :] = -1000
+ return result
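
As a rough illustration of the selection heuristic above, run on random activations (so the scores are meaningless; only the shapes and counts are informative):

    import numpy as np
    from pytorch_grad_cam.ablation_layer import AblationLayer

    layer = AblationLayer()
    activations = np.random.rand(512, 7, 7).astype(np.float32)

    # ratio 1.0 takes the early-return path: every channel is selected.
    print(len(layer.activations_to_be_ablated(activations, 1.0)))   # 512

    # Lower ratios keep both the highest- and the lowest-scoring fraction,
    # so 0.25 returns 2 * 0.25 * 512 = 256 indices with this implementation.
    print(len(layer.activations_to_be_ablated(activations, 0.25)))  # 256
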
diff --git a/pytorch_grad_cam/activations_and_gradients.py b/pytorch_grad_cam/activations_and_gradients.py
new file mode 100644
index 0000000..0c2071e
--- /dev/null
+++ b/pytorch_grad_cam/activations_and_gradients.py
@@ -0,0 +1,46 @@
+class ActivationsAndGradients:
+ """ Class for extracting activations and
+    registering gradients from targeted intermediate layers """
+
+ def __init__(self, model, target_layers, reshape_transform):
+ self.model = model
+ self.gradients = []
+ self.activations = []
+ self.reshape_transform = reshape_transform
+ self.handles = []
+ for target_layer in target_layers:
+ self.handles.append(
+ target_layer.register_forward_hook(self.save_activation))
+ # Because of https://github.com/pytorch/pytorch/issues/61519,
+ # we don't use backward hook to record gradients.
+ self.handles.append(
+ target_layer.register_forward_hook(self.save_gradient))
+
+ def save_activation(self, module, input, output):
+ activation = output
+
+ if self.reshape_transform is not None:
+ activation = self.reshape_transform(activation)
+ self.activations.append(activation.cpu().detach())
+
+ def save_gradient(self, module, input, output):
+ if not hasattr(output, "requires_grad") or not output.requires_grad:
+            # You can only register hooks on tensors that require grad.
+ return
+
+ # Gradients are computed in reverse order
+ def _store_grad(grad):
+ if self.reshape_transform is not None:
+ grad = self.reshape_transform(grad)
+ self.gradients = [grad.cpu().detach()] + self.gradients
+
+ output.register_hook(_store_grad)
+
+ def __call__(self, x):
+ self.gradients = []
+ self.activations = []
+ return self.model(x)
+
+ def release(self):
+ for handle in self.handles:
+ handle.remove()
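
A small sketch of using this helper on its own, assuming a torchvision ResNet-18 (any model and layer pair works the same way):

    import torch
    from torchvision import models
    from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients

    model = models.resnet18(pretrained=True).eval()
    wrapper = ActivationsAndGradients(model, [model.layer4[-1]], reshape_transform=None)

    x = torch.randn(1, 3, 224, 224)
    outputs = wrapper(x)                     # forward pass records activations
    outputs[0, outputs.argmax()].backward()  # backward pass records gradients

    print(wrapper.activations[0].shape)  # torch.Size([1, 512, 7, 7])
    print(wrapper.gradients[0].shape)    # same shape as the activations
    wrapper.release()                    # detach the hooks when done
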
diff --git a/pytorch_grad_cam/base_cam.py b/pytorch_grad_cam/base_cam.py
new file mode 100644
index 0000000..7ee1929
--- /dev/null
+++ b/pytorch_grad_cam/base_cam.py
@@ -0,0 +1,203 @@
+import numpy as np
+import torch
+import ttach as tta
+from typing import Callable, List, Tuple
+from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
+from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
+from pytorch_grad_cam.utils.image import scale_cam_image
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+
+
+class BaseCAM:
+ def __init__(self,
+ model: torch.nn.Module,
+ target_layers: List[torch.nn.Module],
+ use_cuda: bool = False,
+ reshape_transform: Callable = None,
+ compute_input_gradient: bool = False,
+ uses_gradients: bool = True) -> None:
+ self.model = model.eval()
+ self.target_layers = target_layers
+ self.cuda = use_cuda
+ if self.cuda:
+ self.model = model.cuda()
+ self.reshape_transform = reshape_transform
+ self.compute_input_gradient = compute_input_gradient
+ self.uses_gradients = uses_gradients
+ self.activations_and_grads = ActivationsAndGradients(
+ self.model, target_layers, reshape_transform)
+
+ """ Get a vector of weights for every channel in the target layer.
+        Methods that return per-channel weights
+        will typically only need to implement this function. """
+
+ def get_cam_weights(self,
+ input_tensor: torch.Tensor,
+ target_layers: List[torch.nn.Module],
+ targets: List[torch.nn.Module],
+ activations: torch.Tensor,
+ grads: torch.Tensor) -> np.ndarray:
+        raise NotImplementedError("get_cam_weights must be implemented by the CAM subclass")
+
+ def get_cam_image(self,
+ input_tensor: torch.Tensor,
+ target_layer: torch.nn.Module,
+ targets: List[torch.nn.Module],
+ activations: torch.Tensor,
+ grads: torch.Tensor,
+ eigen_smooth: bool = False) -> np.ndarray:
+
+ weights = self.get_cam_weights(input_tensor,
+ target_layer,
+ targets,
+ activations,
+ grads)
+ weighted_activations = weights[:, :, None, None] * activations
+ if eigen_smooth:
+ cam = get_2d_projection(weighted_activations)
+ else:
+ cam = weighted_activations.sum(axis=1)
+ return cam
+
+ def forward(self,
+ input_tensor: torch.Tensor,
+ targets: List[torch.nn.Module],
+ eigen_smooth: bool = False) -> np.ndarray:
+
+ if self.cuda:
+ input_tensor = input_tensor.cuda()
+
+ if self.compute_input_gradient:
+ input_tensor = torch.autograd.Variable(input_tensor,
+ requires_grad=True)
+
+ outputs = self.activations_and_grads(input_tensor)
+ if targets is None:
+ target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
+ targets = [ClassifierOutputTarget(
+ category) for category in target_categories]
+
+ if self.uses_gradients:
+ self.model.zero_grad()
+ loss = sum([target(output)
+ for target, output in zip(targets, outputs)])
+ loss.backward(retain_graph=True)
+
+ # In most of the saliency attribution papers, the saliency is
+ # computed with a single target layer.
+ # Commonly it is the last convolutional layer.
+ # Here we support passing a list with multiple target layers.
+        # It will compute the saliency image for every target layer,
+        # and then aggregate them (with a default mean aggregation).
+        # This gives you more flexibility in case you just want to
+        # use all conv layers, for example, or all BatchNorm layers,
+        # or something else.
+ cam_per_layer = self.compute_cam_per_layer(input_tensor,
+ targets,
+ eigen_smooth)
+ return self.aggregate_multi_layers(cam_per_layer)
+
+ def get_target_width_height(self,
+ input_tensor: torch.Tensor) -> Tuple[int, int]:
+ width, height = input_tensor.size(-1), input_tensor.size(-2)
+ return width, height
+
+ def compute_cam_per_layer(
+ self,
+ input_tensor: torch.Tensor,
+ targets: List[torch.nn.Module],
+ eigen_smooth: bool) -> np.ndarray:
+ activations_list = [a.cpu().data.numpy()
+ for a in self.activations_and_grads.activations]
+ grads_list = [g.cpu().data.numpy()
+ for g in self.activations_and_grads.gradients]
+ target_size = self.get_target_width_height(input_tensor)
+
+ cam_per_target_layer = []
+ # Loop over the saliency image from every layer
+ for i in range(len(self.target_layers)):
+ target_layer = self.target_layers[i]
+ layer_activations = None
+ layer_grads = None
+ if i < len(activations_list):
+ layer_activations = activations_list[i]
+ if i < len(grads_list):
+ layer_grads = grads_list[i]
+
+ cam = self.get_cam_image(input_tensor,
+ target_layer,
+ targets,
+ layer_activations,
+ layer_grads,
+ eigen_smooth)
+ cam = np.maximum(cam, 0)
+ scaled = scale_cam_image(cam, target_size)
+ cam_per_target_layer.append(scaled[:, None, :])
+
+ return cam_per_target_layer
+
+ def aggregate_multi_layers(
+ self,
+ cam_per_target_layer: np.ndarray) -> np.ndarray:
+ cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
+ cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
+ result = np.mean(cam_per_target_layer, axis=1)
+ return scale_cam_image(result)
+
+ def forward_augmentation_smoothing(self,
+ input_tensor: torch.Tensor,
+ targets: List[torch.nn.Module],
+ eigen_smooth: bool = False) -> np.ndarray:
+ transforms = tta.Compose(
+ [
+ tta.HorizontalFlip(),
+ tta.Multiply(factors=[0.9, 1, 1.1]),
+ ]
+ )
+ cams = []
+ for transform in transforms:
+ augmented_tensor = transform.augment_image(input_tensor)
+ cam = self.forward(augmented_tensor,
+ targets,
+ eigen_smooth)
+
+ # The ttach library expects a tensor of size BxCxHxW
+ cam = cam[:, None, :, :]
+ cam = torch.from_numpy(cam)
+ cam = transform.deaugment_mask(cam)
+
+ # Back to numpy float32, HxW
+ cam = cam.numpy()
+ cam = cam[:, 0, :, :]
+ cams.append(cam)
+
+ cam = np.mean(np.float32(cams), axis=0)
+ return cam
+
+ def __call__(self,
+ input_tensor: torch.Tensor,
+ targets: List[torch.nn.Module] = None,
+ aug_smooth: bool = False,
+ eigen_smooth: bool = False) -> np.ndarray:
+
+ # Smooth the CAM result with test time augmentation
+        if aug_smooth:
+ return self.forward_augmentation_smoothing(
+ input_tensor, targets, eigen_smooth)
+
+ return self.forward(input_tensor,
+ targets, eigen_smooth)
+
+ def __del__(self):
+ self.activations_and_grads.release()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_tb):
+ self.activations_and_grads.release()
+ if isinstance(exc_value, IndexError):
+ # Handle IndexError here...
+ print(
+ f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
+ return True
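
Because BaseCAM implements the context-manager protocol, subclasses can be used in a with-block so the hooks are always released; a minimal sketch with GradCAM (model and class indices are illustrative):

    import torch
    from torchvision import models
    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    model = models.resnet50(pretrained=True)
    input_tensor = torch.randn(2, 3, 224, 224)  # a batch of two stand-in images

    with GradCAM(model, target_layers=[model.layer4[-1]]) as cam:
        # targets=None would fall back to each image's highest-scoring class.
        grayscale_cam = cam(input_tensor,
                            targets=[ClassifierOutputTarget(954),
                                     ClassifierOutputTarget(954)],
                            aug_smooth=False,
                            eigen_smooth=False)
    print(grayscale_cam.shape)  # (2, 224, 224)
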
diff --git a/pytorch_grad_cam/eigen_cam.py b/pytorch_grad_cam/eigen_cam.py
new file mode 100644
index 0000000..fd6d6bc
--- /dev/null
+++ b/pytorch_grad_cam/eigen_cam.py
@@ -0,0 +1,23 @@
+from pytorch_grad_cam.base_cam import BaseCAM
+from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
+
+# https://arxiv.org/abs/2008.00299
+
+
+class EigenCAM(BaseCAM):
+ def __init__(self, model, target_layers, use_cuda=False,
+ reshape_transform=None):
+ super(EigenCAM, self).__init__(model,
+ target_layers,
+ use_cuda,
+ reshape_transform,
+ uses_gradients=False)
+
+ def get_cam_image(self,
+ input_tensor,
+ target_layer,
+ target_category,
+ activations,
+ grads,
+ eigen_smooth):
+ return get_2d_projection(activations)
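
get_2d_projection lives in utils/svd_on_activations.py and is not shown here; reading it as a projection of each image's activation maps onto their first principal component, a rough numpy equivalent for a single image would be:

    import numpy as np

    def first_component_projection(activations: np.ndarray) -> np.ndarray:
        """Rough sketch: treat each spatial position as a sample over channels,
        center, run SVD, and project onto the first right singular vector."""
        channels = activations.shape[0]
        flat = activations.reshape(channels, -1).T  # (H*W, C)
        flat = flat - flat.mean(axis=0)
        _, _, vt = np.linalg.svd(flat, full_matrices=True)
        return (flat @ vt[0, :]).reshape(activations.shape[1:])

    cam = first_component_projection(np.random.rand(512, 7, 7).astype(np.float32))
    print(cam.shape)  # (7, 7); BaseCAM then clips negatives and rescales
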
diff --git a/pytorch_grad_cam/eigen_grad_cam.py b/pytorch_grad_cam/eigen_grad_cam.py
new file mode 100644
index 0000000..3932a96
--- /dev/null
+++ b/pytorch_grad_cam/eigen_grad_cam.py
@@ -0,0 +1,21 @@
+from pytorch_grad_cam.base_cam import BaseCAM
+from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
+
+# Like Eigen CAM: https://arxiv.org/abs/2008.00299
+# But multiply the activations x gradients
+
+
+class EigenGradCAM(BaseCAM):
+ def __init__(self, model, target_layers, use_cuda=False,
+ reshape_transform=None):
+ super(EigenGradCAM, self).__init__(model, target_layers, use_cuda,
+ reshape_transform)
+
+ def get_cam_image(self,
+ input_tensor,
+ target_layer,
+ target_category,
+ activations,
+ grads,
+ eigen_smooth):
+ return get_2d_projection(grads * activations)
diff --git a/pytorch_grad_cam/feature_factorization/__init__.py b/pytorch_grad_cam/feature_factorization/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pytorch_grad_cam/feature_factorization/__pycache__/__init__.cpython-37.pyc b/pytorch_grad_cam/feature_factorization/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..6b20a42
Binary files /dev/null and b/pytorch_grad_cam/feature_factorization/__pycache__/__init__.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/feature_factorization/__pycache__/__init__.cpython-38.pyc b/pytorch_grad_cam/feature_factorization/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..6b3636f
Binary files /dev/null and b/pytorch_grad_cam/feature_factorization/__pycache__/__init__.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/feature_factorization/__pycache__/deep_feature_factorization.cpython-37.pyc b/pytorch_grad_cam/feature_factorization/__pycache__/deep_feature_factorization.cpython-37.pyc
new file mode 100644
index 0000000..0b22a48
Binary files /dev/null and b/pytorch_grad_cam/feature_factorization/__pycache__/deep_feature_factorization.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/feature_factorization/__pycache__/deep_feature_factorization.cpython-38.pyc b/pytorch_grad_cam/feature_factorization/__pycache__/deep_feature_factorization.cpython-38.pyc
new file mode 100644
index 0000000..2f61f2f
Binary files /dev/null and b/pytorch_grad_cam/feature_factorization/__pycache__/deep_feature_factorization.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/feature_factorization/deep_feature_factorization.py b/pytorch_grad_cam/feature_factorization/deep_feature_factorization.py
new file mode 100644
index 0000000..b9db2c3
--- /dev/null
+++ b/pytorch_grad_cam/feature_factorization/deep_feature_factorization.py
@@ -0,0 +1,131 @@
+import numpy as np
+from PIL import Image
+import torch
+from typing import Callable, List, Tuple, Optional
+from sklearn.decomposition import NMF
+from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
+from pytorch_grad_cam.utils.image import scale_cam_image, create_labels_legend, show_factorization_on_image
+
+
+def dff(activations: np.ndarray, n_components: int = 5):
+ """ Compute Deep Feature Factorization on a 2d Activations tensor.
+
+ :param activations: A numpy array of shape batch x channels x height x width
+ :param n_components: The number of components for the non negative matrix factorization
+ :returns: A tuple of the concepts (a numpy array with shape channels x components),
+        and the explanation heatmaps (a numpy array with shape batch x components x height x width)
+ """
+
+ batch_size, channels, h, w = activations.shape
+ reshaped_activations = activations.transpose((1, 0, 2, 3))
+ reshaped_activations[np.isnan(reshaped_activations)] = 0
+ reshaped_activations = reshaped_activations.reshape(
+ reshaped_activations.shape[0], -1)
+ offset = reshaped_activations.min(axis=-1)
+ reshaped_activations = reshaped_activations - offset[:, None]
+
+ model = NMF(n_components=n_components, init='random', random_state=0)
+ W = model.fit_transform(reshaped_activations)
+ H = model.components_
+ concepts = W + offset[:, None]
+ explanations = H.reshape(n_components, batch_size, h, w)
+ explanations = explanations.transpose((1, 0, 2, 3))
+ return concepts, explanations
+
+
+class DeepFeatureFactorization:
+ """ Deep Feature Factorization: https://arxiv.org/abs/1806.10206
+    This takes a model, computes the 2D activations for a target layer,
+    and runs non-negative matrix factorization on the activations.
+
+    Optionally it runs a computation on the concept embeddings,
+    like running a classifier on them.
+
+    The explanation heatmaps are scaled to the range [0, 1]
+    and to the input tensor width and height.
+ """
+
+ def __init__(self,
+ model: torch.nn.Module,
+ target_layer: torch.nn.Module,
+ reshape_transform: Callable = None,
+ computation_on_concepts=None
+ ):
+ self.model = model
+ self.computation_on_concepts = computation_on_concepts
+ self.activations_and_grads = ActivationsAndGradients(
+ self.model, [target_layer], reshape_transform)
+
+ def __call__(self,
+ input_tensor: torch.Tensor,
+ n_components: int = 16):
+ batch_size, channels, h, w = input_tensor.size()
+ _ = self.activations_and_grads(input_tensor)
+
+ with torch.no_grad():
+            activations = self.activations_and_grads.activations[0].cpu().numpy()
+
+ concepts, explanations = dff(activations, n_components=n_components)
+
+ processed_explanations = []
+
+ for batch in explanations:
+ processed_explanations.append(scale_cam_image(batch, (w, h)))
+
+ if self.computation_on_concepts:
+ with torch.no_grad():
+ concept_tensors = torch.from_numpy(
+ np.float32(concepts).transpose((1, 0)))
+ concept_outputs = self.computation_on_concepts(
+ concept_tensors).cpu().numpy()
+ return concepts, processed_explanations, concept_outputs
+ else:
+ return concepts, processed_explanations
+
+ def __del__(self):
+ self.activations_and_grads.release()
+
+ def __exit__(self, exc_type, exc_value, exc_tb):
+ self.activations_and_grads.release()
+ if isinstance(exc_value, IndexError):
+ # Handle IndexError here...
+ print(
+ f"An exception occurred in ActivationSummary with block: {exc_type}. Message: {exc_value}")
+ return True
+
+
+def run_dff_on_image(model: torch.nn.Module,
+ target_layer: torch.nn.Module,
+ classifier: torch.nn.Module,
+                     img_pil: Image.Image,
+ img_tensor: torch.Tensor,
+                     reshape_transform: Optional[Callable] = None,
+ n_components: int = 5,
+ top_k: int = 2) -> np.ndarray:
+ """ Helper function to create a Deep Feature Factorization visualization for a single image.
+ TBD: Run this on a batch with several images.
+ """
+ rgb_img_float = np.array(img_pil) / 255
+ dff = DeepFeatureFactorization(model=model,
+ reshape_transform=reshape_transform,
+ target_layer=target_layer,
+ computation_on_concepts=classifier)
+
+ concepts, batch_explanations, concept_outputs = dff(
+ img_tensor[None, :], n_components)
+
+ concept_outputs = torch.softmax(
+ torch.from_numpy(concept_outputs),
+ axis=-1).numpy()
+ concept_label_strings = create_labels_legend(concept_outputs,
+ labels=model.config.id2label,
+ top_k=top_k)
+ visualization = show_factorization_on_image(
+ rgb_img_float,
+ batch_explanations[0],
+ image_weight=0.3,
+ concept_labels=concept_label_strings)
+
+ result = np.hstack((np.array(img_pil), visualization))
+ return result
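
The factorization can be exercised in isolation; a toy sketch on random activations (shapes only, no semantic meaning):

    import numpy as np
    from pytorch_grad_cam.feature_factorization.deep_feature_factorization import dff

    # Pretend batch of 2 images with 256 channels of 7x7 activations.
    activations = np.random.rand(2, 256, 7, 7).astype(np.float32)
    concepts, explanations = dff(activations, n_components=5)

    print(concepts.shape)      # (256, 5): one channel-space vector per concept
    print(explanations.shape)  # (2, 5, 7, 7): per image, per concept heatmaps
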
diff --git a/pytorch_grad_cam/fullgrad_cam.py b/pytorch_grad_cam/fullgrad_cam.py
new file mode 100644
index 0000000..1a2685e
--- /dev/null
+++ b/pytorch_grad_cam/fullgrad_cam.py
@@ -0,0 +1,95 @@
+import numpy as np
+import torch
+from pytorch_grad_cam.base_cam import BaseCAM
+from pytorch_grad_cam.utils.find_layers import find_layer_predicate_recursive
+from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
+from pytorch_grad_cam.utils.image import scale_accross_batch_and_channels, scale_cam_image
+
+# https://arxiv.org/abs/1905.00780
+
+
+class FullGrad(BaseCAM):
+ def __init__(self, model, target_layers, use_cuda=False,
+ reshape_transform=None):
+ if len(target_layers) > 0:
+ print(
+ "Warning: target_layers is ignored in FullGrad. All bias layers will be used instead")
+
+ def layer_with_2D_bias(layer):
+ bias_target_layers = [torch.nn.Conv2d, torch.nn.BatchNorm2d]
+ if type(layer) in bias_target_layers and layer.bias is not None:
+ return True
+ return False
+ target_layers = find_layer_predicate_recursive(
+ model, layer_with_2D_bias)
+ super(
+ FullGrad,
+ self).__init__(
+ model,
+ target_layers,
+ use_cuda,
+ reshape_transform,
+ compute_input_gradient=True)
+        self.bias_data = [self.get_bias_data(layer).cpu().numpy() for layer in target_layers]
+
+ def get_bias_data(self, layer):
+ # Borrowed from official paper impl:
+ # https://github.com/idiap/fullgrad-saliency/blob/master/saliency/tensor_extractor.py#L47
+ if isinstance(layer, torch.nn.BatchNorm2d):
+ bias = - (layer.running_mean * layer.weight
+ / torch.sqrt(layer.running_var + layer.eps)) + layer.bias
+ return bias.data
+ else:
+ return layer.bias.data
+
+ def compute_cam_per_layer(
+ self,
+ input_tensor,
+ target_category,
+ eigen_smooth):
+ input_grad = input_tensor.grad.data.cpu().numpy()
+ grads_list = [g.cpu().data.numpy() for g in
+ self.activations_and_grads.gradients]
+ cam_per_target_layer = []
+ target_size = self.get_target_width_height(input_tensor)
+
+ gradient_multiplied_input = input_grad * input_tensor.data.cpu().numpy()
+ gradient_multiplied_input = np.abs(gradient_multiplied_input)
+ gradient_multiplied_input = scale_accross_batch_and_channels(
+ gradient_multiplied_input,
+ target_size)
+ cam_per_target_layer.append(gradient_multiplied_input)
+
+ # Loop over the saliency image from every layer
+        assert len(self.bias_data) == len(grads_list)
+ for bias, grads in zip(self.bias_data, grads_list):
+ bias = bias[None, :, None, None]
+ # In the paper they take the absolute value,
+            # but possibly taking only the positive gradients will work
+ # better.
+ bias_grad = np.abs(bias * grads)
+ result = scale_accross_batch_and_channels(
+ bias_grad, target_size)
+ result = np.sum(result, axis=1)
+ cam_per_target_layer.append(result[:, None, :])
+ cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
+ if eigen_smooth:
+            # Resize to a smaller image, since this method typically has a very large number of channels
+            # and would otherwise consume a lot of memory.
+ cam_per_target_layer = scale_accross_batch_and_channels(
+ cam_per_target_layer, (target_size[0] // 8, target_size[1] // 8))
+ cam_per_target_layer = get_2d_projection(cam_per_target_layer)
+ cam_per_target_layer = cam_per_target_layer[:, None, :, :]
+ cam_per_target_layer = scale_accross_batch_and_channels(
+ cam_per_target_layer,
+ target_size)
+ else:
+ cam_per_target_layer = np.sum(
+ cam_per_target_layer, axis=1)[:, None, :]
+
+ return cam_per_target_layer
+
+ def aggregate_multi_layers(self, cam_per_target_layer):
+ result = np.sum(cam_per_target_layer, axis=1)
+ return scale_cam_image(result)
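
The bias folding in get_bias_data can be sanity-checked numerically: in eval mode a BatchNorm layer maps zeros to exactly its effective bias. A small self-contained check, with arbitrary values for the running statistics:

    import torch

    bn = torch.nn.BatchNorm2d(8).eval()
    bn.running_mean.uniform_(-1, 1)   # give the running stats arbitrary values
    bn.running_var.uniform_(0.5, 2.0)

    # BN(x) = weight * (x - mean) / sqrt(var + eps) + bias, so the constant
    # term is bias - weight * mean / sqrt(var + eps), as in get_bias_data above.
    effective_bias = bn.bias - bn.weight * bn.running_mean / torch.sqrt(bn.running_var + bn.eps)

    out = bn(torch.zeros(1, 8, 4, 4))
    print(torch.allclose(out[0, :, 0, 0], effective_bias))  # True
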
diff --git a/pytorch_grad_cam/grad_cam.py b/pytorch_grad_cam/grad_cam.py
new file mode 100644
index 0000000..025bf45
--- /dev/null
+++ b/pytorch_grad_cam/grad_cam.py
@@ -0,0 +1,22 @@
+import numpy as np
+from pytorch_grad_cam.base_cam import BaseCAM
+
+
+class GradCAM(BaseCAM):
+ def __init__(self, model, target_layers, use_cuda=False,
+ reshape_transform=None):
+ super(
+ GradCAM,
+ self).__init__(
+ model,
+ target_layers,
+ use_cuda,
+ reshape_transform)
+
+ def get_cam_weights(self,
+ input_tensor,
+ target_layer,
+ target_category,
+ activations,
+ grads):
+ return np.mean(grads, axis=(2, 3))
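
Spelled out in plain numpy, the whole method is global-average-pooled gradients used as channel weights, followed by the weighted sum and ReLU that BaseCAM applies (random stand-in tensors):

    import numpy as np

    # One image: 512 channels of 7x7 activations and their gradients.
    activations = np.random.rand(1, 512, 7, 7).astype(np.float32)
    grads = np.random.randn(1, 512, 7, 7).astype(np.float32)

    weights = np.mean(grads, axis=(2, 3))                        # get_cam_weights above
    cam = (weights[:, :, None, None] * activations).sum(axis=1)  # weighted channel sum
    cam = np.maximum(cam, 0)                                     # ReLU, as in BaseCAM
    print(cam.shape)  # (1, 7, 7); later rescaled to the input resolution
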
diff --git a/pytorch_grad_cam/grad_cam_elementwise.py b/pytorch_grad_cam/grad_cam_elementwise.py
new file mode 100644
index 0000000..2698d47
--- /dev/null
+++ b/pytorch_grad_cam/grad_cam_elementwise.py
@@ -0,0 +1,30 @@
+import numpy as np
+from pytorch_grad_cam.base_cam import BaseCAM
+from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
+
+
+class GradCAMElementWise(BaseCAM):
+ def __init__(self, model, target_layers, use_cuda=False,
+ reshape_transform=None):
+ super(
+ GradCAMElementWise,
+ self).__init__(
+ model,
+ target_layers,
+ use_cuda,
+ reshape_transform)
+
+ def get_cam_image(self,
+ input_tensor,
+ target_layer,
+ target_category,
+ activations,
+ grads,
+ eigen_smooth):
+ elementwise_activations = np.maximum(grads * activations, 0)
+
+ if eigen_smooth:
+ cam = get_2d_projection(elementwise_activations)
+ else:
+ cam = elementwise_activations.sum(axis=1)
+ return cam
diff --git a/pytorch_grad_cam/grad_cam_plusplus.py b/pytorch_grad_cam/grad_cam_plusplus.py
new file mode 100644
index 0000000..4466826
--- /dev/null
+++ b/pytorch_grad_cam/grad_cam_plusplus.py
@@ -0,0 +1,32 @@
+import numpy as np
+from pytorch_grad_cam.base_cam import BaseCAM
+
+# https://arxiv.org/abs/1710.11063
+
+
+class GradCAMPlusPlus(BaseCAM):
+ def __init__(self, model, target_layers, use_cuda=False,
+ reshape_transform=None):
+ super(GradCAMPlusPlus, self).__init__(model, target_layers, use_cuda,
+ reshape_transform)
+
+ def get_cam_weights(self,
+ input_tensor,
+ target_layers,
+ target_category,
+ activations,
+ grads):
+ grads_power_2 = grads**2
+ grads_power_3 = grads_power_2 * grads
+ # Equation 19 in https://arxiv.org/abs/1710.11063
+ sum_activations = np.sum(activations, axis=(2, 3))
+ eps = 0.000001
+ aij = grads_power_2 / (2 * grads_power_2 +
+ sum_activations[:, :, None, None] * grads_power_3 + eps)
+        # Now bring back the ReLU from eq. 7 in the paper,
+        # and zero out the aij terms where the gradients are 0.
+ aij = np.where(grads != 0, aij, 0)
+
+ weights = np.maximum(grads, 0) * aij
+ weights = np.sum(weights, axis=(2, 3))
+ return weights
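
The same weighting, restated outside the class for quick experimentation, with the plain Grad-CAM weights computed alongside for comparison (random stand-in tensors):

    import numpy as np

    activations = np.random.rand(1, 64, 7, 7).astype(np.float32)
    grads = np.random.randn(1, 64, 7, 7).astype(np.float32)

    # Equation 19, as implemented above.
    grads_power_2 = grads ** 2
    grads_power_3 = grads_power_2 * grads
    sum_activations = np.sum(activations, axis=(2, 3))
    aij = grads_power_2 / (2 * grads_power_2 +
                           sum_activations[:, :, None, None] * grads_power_3 + 1e-6)
    aij = np.where(grads != 0, aij, 0)

    weights_pp = np.sum(np.maximum(grads, 0) * aij, axis=(2, 3))  # Grad-CAM++
    weights_gc = np.mean(grads, axis=(2, 3))                      # plain Grad-CAM
    print(weights_pp.shape, weights_gc.shape)  # (1, 64) (1, 64)
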
diff --git a/pytorch_grad_cam/guided_backprop.py b/pytorch_grad_cam/guided_backprop.py
new file mode 100644
index 0000000..602fbf3
--- /dev/null
+++ b/pytorch_grad_cam/guided_backprop.py
@@ -0,0 +1,100 @@
+import numpy as np
+import torch
+from torch.autograd import Function
+from pytorch_grad_cam.utils.find_layers import replace_all_layer_type_recursive
+
+
+class GuidedBackpropReLU(Function):
+ @staticmethod
+ def forward(self, input_img):
+ positive_mask = (input_img > 0).type_as(input_img)
+ output = torch.addcmul(
+ torch.zeros(
+ input_img.size()).type_as(input_img),
+ input_img,
+ positive_mask)
+ self.save_for_backward(input_img, output)
+ return output
+
+ @staticmethod
+ def backward(self, grad_output):
+ input_img, output = self.saved_tensors
+ grad_input = None
+
+ positive_mask_1 = (input_img > 0).type_as(grad_output)
+ positive_mask_2 = (grad_output > 0).type_as(grad_output)
+ grad_input = torch.addcmul(
+ torch.zeros(
+ input_img.size()).type_as(input_img),
+ torch.addcmul(
+ torch.zeros(
+ input_img.size()).type_as(input_img),
+ grad_output,
+ positive_mask_1),
+ positive_mask_2)
+ return grad_input
+
+
+class GuidedBackpropReLUasModule(torch.nn.Module):
+ def __init__(self):
+ super(GuidedBackpropReLUasModule, self).__init__()
+
+ def forward(self, input_img):
+ return GuidedBackpropReLU.apply(input_img)
+
+
+class GuidedBackpropReLUModel:
+ def __init__(self, model, use_cuda):
+ self.model = model
+ self.model.eval()
+ self.cuda = use_cuda
+ if self.cuda:
+ self.model = self.model.cuda()
+
+ def forward(self, input_img):
+ return self.model(input_img)
+
+ def recursive_replace_relu_with_guidedrelu(self, module_top):
+
+ for idx, module in module_top._modules.items():
+ self.recursive_replace_relu_with_guidedrelu(module)
+ if module.__class__.__name__ == 'ReLU':
+ module_top._modules[idx] = GuidedBackpropReLU.apply
+
+ def recursive_replace_guidedrelu_with_relu(self, module_top):
+ try:
+ for idx, module in module_top._modules.items():
+ self.recursive_replace_guidedrelu_with_relu(module)
+ if module == GuidedBackpropReLU.apply:
+ module_top._modules[idx] = torch.nn.ReLU()
+ except BaseException:
+ pass
+
+ def __call__(self, input_img, target_category=None):
+ replace_all_layer_type_recursive(self.model,
+ torch.nn.ReLU,
+ GuidedBackpropReLUasModule())
+
+ if self.cuda:
+ input_img = input_img.cuda()
+
+ input_img = input_img.requires_grad_(True)
+
+ output = self.forward(input_img)
+
+ if target_category is None:
+ target_category = np.argmax(output.cpu().data.numpy())
+
+ loss = output[0, target_category]
+ loss.backward(retain_graph=True)
+
+ output = input_img.grad.cpu().data.numpy()
+ output = output[0, :, :, :]
+ output = output.transpose((1, 2, 0))
+
+ replace_all_layer_type_recursive(self.model,
+ GuidedBackpropReLUasModule,
+ torch.nn.ReLU())
+
+ return output
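+
+
+# Minimal usage sketch (an illustrative addition): guided backpropagation
+# returns an input-sized (H, W, C) gradient image, usually passed through
+# deprocess_image afterwards for display. The model choice is an assumption.
+if __name__ == '__main__':
+    from torchvision import models
+
+    model = models.resnet50(pretrained=True)
+    gb_model = GuidedBackpropReLUModel(model=model, use_cuda=False)
+    input_img = torch.randn(1, 3, 224, 224)  # stands in for a preprocessed image
+    gb = gb_model(input_img, target_category=None)
+    print(gb.shape)  # (224, 224, 3)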
diff --git a/pytorch_grad_cam/hirescam.py b/pytorch_grad_cam/hirescam.py
new file mode 100644
index 0000000..381d8d4
--- /dev/null
+++ b/pytorch_grad_cam/hirescam.py
@@ -0,0 +1,32 @@
+import numpy as np
+from pytorch_grad_cam.base_cam import BaseCAM
+from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
+
+
+class HiResCAM(BaseCAM):
+ def __init__(self, model, target_layers, use_cuda=False,
+ reshape_transform=None):
+ super(
+ HiResCAM,
+ self).__init__(
+ model,
+ target_layers,
+ use_cuda,
+ reshape_transform)
+
+ def get_cam_image(self,
+ input_tensor,
+ target_layer,
+ target_category,
+ activations,
+ grads,
+ eigen_smooth):
+ elementwise_activations = grads * activations
+
+ if eigen_smooth:
+ print(
+ "Warning: HiResCAM's faithfulness guarantees do not hold if smoothing is applied")
+ cam = get_2d_projection(elementwise_activations)
+ else:
+ cam = elementwise_activations.sum(axis=1)
+ return cam
diff --git a/pytorch_grad_cam/layer_cam.py b/pytorch_grad_cam/layer_cam.py
new file mode 100644
index 0000000..971443d
--- /dev/null
+++ b/pytorch_grad_cam/layer_cam.py
@@ -0,0 +1,36 @@
+import numpy as np
+from pytorch_grad_cam.base_cam import BaseCAM
+from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
+
+# https://ieeexplore.ieee.org/document/9462463
+
+
+class LayerCAM(BaseCAM):
+ def __init__(
+ self,
+ model,
+ target_layers,
+ use_cuda=False,
+ reshape_transform=None):
+ super(
+ LayerCAM,
+ self).__init__(
+ model,
+ target_layers,
+ use_cuda,
+ reshape_transform)
+
+ def get_cam_image(self,
+ input_tensor,
+ target_layer,
+ target_category,
+ activations,
+ grads,
+ eigen_smooth):
+ spatial_weighted_activations = np.maximum(grads, 0) * activations
+
+ if eigen_smooth:
+ cam = get_2d_projection(spatial_weighted_activations)
+ else:
+ cam = spatial_weighted_activations.sum(axis=1)
+ return cam
diff --git a/pytorch_grad_cam/metrics/__init__.py b/pytorch_grad_cam/metrics/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pytorch_grad_cam/metrics/cam_mult_image.py b/pytorch_grad_cam/metrics/cam_mult_image.py
new file mode 100644
index 0000000..bd4bf8a
--- /dev/null
+++ b/pytorch_grad_cam/metrics/cam_mult_image.py
@@ -0,0 +1,37 @@
+import torch
+import numpy as np
+from typing import List, Callable
+from pytorch_grad_cam.metrics.perturbation_confidence import PerturbationConfidenceMetric
+
+
+def multiply_tensor_with_cam(input_tensor: torch.Tensor,
+ cam: torch.Tensor):
+ """ Multiply an input tensor (after normalization)
+ with a pixel attribution map
+ """
+ return input_tensor * cam
+
+
+class CamMultImageConfidenceChange(PerturbationConfidenceMetric):
+ def __init__(self):
+ super(CamMultImageConfidenceChange,
+ self).__init__(multiply_tensor_with_cam)
+
+
+class DropInConfidence(CamMultImageConfidenceChange):
+ def __init__(self):
+ super(DropInConfidence, self).__init__()
+
+ def __call__(self, *args, **kwargs):
+ scores = super(DropInConfidence, self).__call__(*args, **kwargs)
+ scores = -scores
+ return np.maximum(scores, 0)
+
+
+class IncreaseInConfidence(CamMultImageConfidenceChange):
+ def __init__(self):
+ super(IncreaseInConfidence, self).__init__()
+
+ def __call__(self, *args, **kwargs):
+ scores = super(IncreaseInConfidence, self).__call__(*args, **kwargs)
+ return np.float32(scores > 0)
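+
+
+# Minimal usage sketch (an illustrative addition): the metric multiplies each
+# normalized input by its CAM, re-runs the model, and reports the change in
+# the target score. Model, CAMs and the class index are placeholders.
+if __name__ == '__main__':
+    from torchvision import models
+    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+
+    model = models.resnet50(pretrained=True).eval()
+    input_tensor = torch.randn(2, 3, 224, 224)
+    cams = np.random.uniform(0, 1, size=(2, 224, 224)).astype(np.float32)
+    targets = [ClassifierOutputTarget(0), ClassifierOutputTarget(0)]
+    drop = DropInConfidence()(input_tensor, cams, targets, model)
+    print(drop)  # non-negative confidence drop per example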
diff --git a/pytorch_grad_cam/metrics/perturbation_confidence.py b/pytorch_grad_cam/metrics/perturbation_confidence.py
new file mode 100644
index 0000000..813ffc7
--- /dev/null
+++ b/pytorch_grad_cam/metrics/perturbation_confidence.py
@@ -0,0 +1,109 @@
+import torch
+import numpy as np
+import cv2
+from typing import List, Callable
+
+
+class PerturbationConfidenceMetric:
+ def __init__(self, perturbation):
+ self.perturbation = perturbation
+
+ def __call__(self, input_tensor: torch.Tensor,
+ cams: np.ndarray,
+ targets: List[Callable],
+ model: torch.nn.Module,
+ return_visualization=False,
+ return_diff=True):
+
+ if return_diff:
+ with torch.no_grad():
+ outputs = model(input_tensor)
+ scores = [target(output).cpu().numpy()
+ for target, output in zip(targets, outputs)]
+ scores = np.float32(scores)
+
+ batch_size = input_tensor.size(0)
+ perturbated_tensors = []
+ for i in range(batch_size):
+ cam = cams[i]
+ tensor = self.perturbation(input_tensor[i, ...].cpu(),
+ torch.from_numpy(cam))
+ tensor = tensor.to(input_tensor.device)
+ perturbated_tensors.append(tensor.unsqueeze(0))
+ perturbated_tensors = torch.cat(perturbated_tensors)
+
+ with torch.no_grad():
+ outputs_after_imputation = model(perturbated_tensors)
+ scores_after_imputation = [
+ target(output).cpu().numpy() for target, output in zip(
+ targets, outputs_after_imputation)]
+ scores_after_imputation = np.float32(scores_after_imputation)
+
+ if return_diff:
+ result = scores_after_imputation - scores
+ else:
+ result = scores_after_imputation
+
+ if return_visualization:
+ return result, perturbated_tensors
+ else:
+ return result
+
+
+class RemoveMostRelevantFirst:
+ def __init__(self, percentile, imputer):
+ self.percentile = percentile
+ self.imputer = imputer
+
+ def __call__(self, input_tensor, mask):
+ imputer = self.imputer
+ if self.percentile != 'auto':
+ threshold = np.percentile(mask.cpu().numpy(), self.percentile)
+ binary_mask = np.float32(mask < threshold)
+ else:
+ _, binary_mask = cv2.threshold(
+ np.uint8(mask * 255), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+
+ binary_mask = torch.from_numpy(binary_mask)
+ binary_mask = binary_mask.to(mask.device)
+ return imputer(input_tensor, binary_mask)
+
+
+class RemoveLeastRelevantFirst(RemoveMostRelevantFirst):
+ def __init__(self, percentile, imputer):
+ super(RemoveLeastRelevantFirst, self).__init__(percentile, imputer)
+
+ def __call__(self, input_tensor, mask):
+ return super(RemoveLeastRelevantFirst, self).__call__(
+ input_tensor, 1 - mask)
+
+
+class AveragerAcrossThresholds:
+    def __init__(self, imputer,
+                 percentiles=[10, 20, 30, 40, 50, 60, 70, 80, 90]):
+ self.imputer = imputer
+ self.percentiles = percentiles
+
+ def __call__(self,
+ input_tensor: torch.Tensor,
+ cams: np.ndarray,
+ targets: List[Callable],
+ model: torch.nn.Module):
+ scores = []
+ for percentile in self.percentiles:
+ imputer = self.imputer(percentile)
+ scores.append(imputer(input_tensor, cams, targets, model))
+ return np.mean(np.float32(scores), axis=0)
diff --git a/pytorch_grad_cam/metrics/road.py b/pytorch_grad_cam/metrics/road.py
new file mode 100644
index 0000000..7b09c4b
--- /dev/null
+++ b/pytorch_grad_cam/metrics/road.py
@@ -0,0 +1,181 @@
+# A Consistent and Efficient Evaluation Strategy for Attribution Methods
+# https://arxiv.org/abs/2202.00449
+# Taken from https://raw.githubusercontent.com/tleemann/road_evaluation/main/imputations.py
+# MIT License
+
+# Copyright (c) 2022 Tobias Leemann
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+# Implementations of our imputation models.
+import torch
+import numpy as np
+from scipy.sparse import lil_matrix, csc_matrix
+from scipy.sparse.linalg import spsolve
+from typing import List, Callable
+from pytorch_grad_cam.metrics.perturbation_confidence import PerturbationConfidenceMetric, \
+ AveragerAcrossThresholds, \
+ RemoveMostRelevantFirst, \
+ RemoveLeastRelevantFirst
+
+# The weights of the surrounding pixels
+neighbors_weights = [((1, 1), 1 / 12),
+ ((0, 1), 1 / 6),
+ ((-1, 1), 1 / 12),
+ ((1, -1), 1 / 12),
+ ((0, -1), 1 / 6),
+ ((-1, -1), 1 / 12),
+ ((1, 0), 1 / 6),
+ ((-1, 0), 1 / 6)]
+
+
+class NoisyLinearImputer:
+ def __init__(self,
+ noise: float = 0.01,
+ weighting: List[float] = neighbors_weights):
+ """
+ Noisy linear imputation.
+ noise: magnitude of noise to add (absolute, set to 0 for no noise)
+ weighting: Weights of the neighboring pixels in the computation.
+ List of tuples of (offset, weight)
+ """
+        self.noise = noise
+        self.weighting = weighting
+
+ @staticmethod
+ def add_offset_to_indices(indices, offset, mask_shape):
+ """ Add the corresponding offset to the indices.
+ Return new indices plus a valid bit-vector. """
+ cord1 = indices % mask_shape[1]
+ cord0 = indices // mask_shape[1]
+ cord0 += offset[0]
+ cord1 += offset[1]
+ valid = ((cord0 < 0) | (cord1 < 0) |
+ (cord0 >= mask_shape[0]) |
+ (cord1 >= mask_shape[1]))
+ return ~valid, indices + offset[0] * mask_shape[1] + offset[1]
+
+ @staticmethod
+ def setup_sparse_system(mask, img, neighbors_weights):
+ """ Vectorized version to set up the equation system.
+ mask: (H, W)-tensor of missing pixels.
+ Image: (H, W, C)-tensor of all values.
+ Return (N,N)-System matrix, (N,C)-Right hand side for each of the C channels.
+ """
+ maskflt = mask.flatten()
+ imgflat = img.reshape((img.shape[0], -1))
+ # Indices that are imputed in the flattened mask:
+ indices = np.argwhere(maskflt == 0).flatten()
+ coords_to_vidx = np.zeros(len(maskflt), dtype=int)
+ coords_to_vidx[indices] = np.arange(len(indices))
+ numEquations = len(indices)
+ # System matrix:
+ A = lil_matrix((numEquations, numEquations))
+ b = np.zeros((numEquations, img.shape[0]))
+ # Sum of weights assigned:
+ sum_neighbors = np.ones(numEquations)
+ for n in neighbors_weights:
+ offset, weight = n[0], n[1]
+ # Take out outliers
+ valid, new_coords = NoisyLinearImputer.add_offset_to_indices(
+ indices, offset, mask.shape)
+ valid_coords = new_coords[valid]
+ valid_ids = np.argwhere(valid == 1).flatten()
+ # Add values to the right hand-side
+ has_values_coords = valid_coords[maskflt[valid_coords] > 0.5]
+ has_values_ids = valid_ids[maskflt[valid_coords] > 0.5]
+ b[has_values_ids, :] -= weight * imgflat[:, has_values_coords].T
+            # Add weights to the system (left hand side):
+            # find the variable indices of the missing-value coordinates.
+ has_no_values = valid_coords[maskflt[valid_coords] < 0.5]
+ variable_ids = coords_to_vidx[has_no_values]
+ has_no_values_ids = valid_ids[maskflt[valid_coords] < 0.5]
+ A[has_no_values_ids, variable_ids] = weight
+ # Reduce weight for invalid
+ sum_neighbors[np.argwhere(valid == 0).flatten()] = \
+ sum_neighbors[np.argwhere(valid == 0).flatten()] - weight
+
+ A[np.arange(numEquations), np.arange(numEquations)] = -sum_neighbors
+ return A, b
+
+ def __call__(self, img: torch.Tensor, mask: torch.Tensor):
+ """ Our linear inputation scheme. """
+ """
+ This is the function to do the linear infilling
+ img: original image (C,H,W)-tensor;
+ mask: mask; (H,W)-tensor
+
+ """
+ imgflt = img.reshape(img.shape[0], -1)
+ maskflt = mask.reshape(-1)
+ # Indices that need to be imputed.
+ indices_linear = np.argwhere(maskflt == 0).flatten()
+ # Set up sparse equation system, solve system.
+        A, b = NoisyLinearImputer.setup_sparse_system(
+            mask.numpy(), img.numpy(), self.weighting)
+ res = torch.tensor(spsolve(csc_matrix(A), b), dtype=torch.float)
+
+ # Fill the values with the solution of the system.
+ img_infill = imgflt.clone()
+ img_infill[:, indices_linear] = res.t() + self.noise * \
+ torch.randn_like(res.t())
+
+ return img_infill.reshape_as(img)
+
+
+class ROADMostRelevantFirst(PerturbationConfidenceMetric):
+ def __init__(self, percentile=80):
+ super(ROADMostRelevantFirst, self).__init__(
+ RemoveMostRelevantFirst(percentile, NoisyLinearImputer()))
+
+
+class ROADLeastRelevantFirst(PerturbationConfidenceMetric):
+ def __init__(self, percentile=20):
+ super(ROADLeastRelevantFirst, self).__init__(
+ RemoveLeastRelevantFirst(percentile, NoisyLinearImputer()))
+
+
+class ROADMostRelevantFirstAverage(AveragerAcrossThresholds):
+ def __init__(self, percentiles=[10, 20, 30, 40, 50, 60, 70, 80, 90]):
+ super(ROADMostRelevantFirstAverage, self).__init__(
+ ROADMostRelevantFirst, percentiles)
+
+
+class ROADLeastRelevantFirstAverage(AveragerAcrossThresholds):
+ def __init__(self, percentiles=[10, 20, 30, 40, 50, 60, 70, 80, 90]):
+ super(ROADLeastRelevantFirstAverage, self).__init__(
+ ROADLeastRelevantFirst, percentiles)
+
+
+class ROADCombined:
+ def __init__(self, percentiles=[10, 20, 30, 40, 50, 60, 70, 80, 90]):
+ self.percentiles = percentiles
+ self.morf_averager = ROADMostRelevantFirstAverage(percentiles)
+ self.lerf_averager = ROADLeastRelevantFirstAverage(percentiles)
+
+ def __call__(self,
+ input_tensor: torch.Tensor,
+ cams: np.ndarray,
+ targets: List[Callable],
+ model: torch.nn.Module):
+
+ scores_lerf = self.lerf_averager(input_tensor, cams, targets, model)
+ scores_morf = self.morf_averager(input_tensor, cams, targets, model)
+ return (scores_lerf - scores_morf) / 2
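+
+
+# Minimal usage sketch (an illustrative addition): the combined ROAD score is
+# (LeRF - MoRF) / 2 averaged over the percentiles; higher means the CAM ranks
+# relevant pixels better. Model, CAM and class index are placeholders.
+if __name__ == '__main__':
+    from torchvision import models
+    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+
+    model = models.resnet50(pretrained=True).eval()
+    input_tensor = torch.randn(1, 3, 224, 224)
+    cams = np.random.uniform(0, 1, size=(1, 224, 224)).astype(np.float32)
+    targets = [ClassifierOutputTarget(0)]
+    score = ROADCombined(percentiles=[20, 50, 80])(input_tensor, cams, targets, model)
+    print(score)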
diff --git a/pytorch_grad_cam/random_cam.py b/pytorch_grad_cam/random_cam.py
new file mode 100644
index 0000000..5bb6ecc
--- /dev/null
+++ b/pytorch_grad_cam/random_cam.py
@@ -0,0 +1,22 @@
+import numpy as np
+from pytorch_grad_cam.base_cam import BaseCAM
+
+
+class RandomCAM(BaseCAM):
+ def __init__(self, model, target_layers, use_cuda=False,
+ reshape_transform=None):
+ super(
+ RandomCAM,
+ self).__init__(
+ model,
+ target_layers,
+ use_cuda,
+ reshape_transform)
+
+ def get_cam_weights(self,
+ input_tensor,
+ target_layer,
+ target_category,
+ activations,
+ grads):
+ return np.random.uniform(-1, 1, size=(grads.shape[0], grads.shape[1]))
diff --git a/pytorch_grad_cam/score_cam.py b/pytorch_grad_cam/score_cam.py
new file mode 100644
index 0000000..38460c5
--- /dev/null
+++ b/pytorch_grad_cam/score_cam.py
@@ -0,0 +1,60 @@
+import torch
+import tqdm
+from pytorch_grad_cam.base_cam import BaseCAM
+
+
+class ScoreCAM(BaseCAM):
+ def __init__(
+ self,
+ model,
+ target_layers,
+ use_cuda=False,
+ reshape_transform=None):
+ super(ScoreCAM, self).__init__(model,
+ target_layers,
+ use_cuda,
+ reshape_transform=reshape_transform,
+ uses_gradients=False)
+
+ def get_cam_weights(self,
+ input_tensor,
+ target_layer,
+ targets,
+ activations,
+ grads):
+ with torch.no_grad():
+ upsample = torch.nn.UpsamplingBilinear2d(
+ size=input_tensor.shape[-2:])
+ activation_tensor = torch.from_numpy(activations)
+ if self.cuda:
+ activation_tensor = activation_tensor.cuda()
+
+ upsampled = upsample(activation_tensor)
+
+ maxs = upsampled.view(upsampled.size(0),
+ upsampled.size(1), -1).max(dim=-1)[0]
+ mins = upsampled.view(upsampled.size(0),
+ upsampled.size(1), -1).min(dim=-1)[0]
+
+ maxs, mins = maxs[:, :, None, None], mins[:, :, None, None]
+ upsampled = (upsampled - mins) / (maxs - mins)
+
+ input_tensors = input_tensor[:, None,
+ :, :] * upsampled[:, :, None, :, :]
+
+ if hasattr(self, "batch_size"):
+ BATCH_SIZE = self.batch_size
+ else:
+ BATCH_SIZE = 16
+
+ scores = []
+ for target, tensor in zip(targets, input_tensors):
+ for i in tqdm.tqdm(range(0, tensor.size(0), BATCH_SIZE)):
+ batch = tensor[i: i + BATCH_SIZE, :]
+ outputs = [target(o).cpu().item()
+ for o in self.model(batch)]
+ scores.extend(outputs)
+ scores = torch.Tensor(scores)
+ scores = scores.view(activations.shape[0], activations.shape[1])
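+        # ScoreCAM weights: softmax over the per-channel target scores, so
+        # channels whose masked inputs score higher on the target get larger weights.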
+ weights = torch.nn.Softmax(dim=-1)(scores).numpy()
+ return weights
diff --git a/pytorch_grad_cam/sobel_cam.py b/pytorch_grad_cam/sobel_cam.py
new file mode 100644
index 0000000..84168a7
--- /dev/null
+++ b/pytorch_grad_cam/sobel_cam.py
@@ -0,0 +1,11 @@
+import cv2
+
+
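+# Edge-detection baseline: a gradient-magnitude map that looks like a "CAM"
+# but ignores the model entirely, useful as a sanity-check reference.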
+def sobel_cam(img):
+ gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+ grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
+ grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
+ abs_grad_x = cv2.convertScaleAbs(grad_x)
+ abs_grad_y = cv2.convertScaleAbs(grad_y)
+ grad = cv2.addWeighted(abs_grad_x, 0.5, abs_grad_y, 0.5, 0)
+ return grad
diff --git a/pytorch_grad_cam/utils/__init__.py b/pytorch_grad_cam/utils/__init__.py
new file mode 100644
index 0000000..269a526
--- /dev/null
+++ b/pytorch_grad_cam/utils/__init__.py
@@ -0,0 +1,4 @@
+from pytorch_grad_cam.utils.image import deprocess_image
+from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
+from pytorch_grad_cam.utils import model_targets
+from pytorch_grad_cam.utils import reshape_transforms
diff --git a/pytorch_grad_cam/utils/find_layers.py b/pytorch_grad_cam/utils/find_layers.py
new file mode 100644
index 0000000..4b9e445
--- /dev/null
+++ b/pytorch_grad_cam/utils/find_layers.py
@@ -0,0 +1,30 @@
+def replace_layer_recursive(model, old_layer, new_layer):
+ for name, layer in model._modules.items():
+ if layer == old_layer:
+ model._modules[name] = new_layer
+ return True
+ elif replace_layer_recursive(layer, old_layer, new_layer):
+ return True
+ return False
+
+
+def replace_all_layer_type_recursive(model, old_layer_type, new_layer):
+ for name, layer in model._modules.items():
+ if isinstance(layer, old_layer_type):
+ model._modules[name] = new_layer
+ replace_all_layer_type_recursive(layer, old_layer_type, new_layer)
+
+
+def find_layer_types_recursive(model, layer_types):
+ def predicate(layer):
+ return type(layer) in layer_types
+ return find_layer_predicate_recursive(model, predicate)
+
+
+def find_layer_predicate_recursive(model, predicate):
+ result = []
+ for name, layer in model._modules.items():
+ if predicate(layer):
+ result.append(layer)
+ result.extend(find_layer_predicate_recursive(layer, predicate))
+ return result
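+
+
+# Minimal usage sketch (an illustrative addition): collect candidate target
+# layers by type. The model below is an assumption; any torch.nn.Module works.
+if __name__ == '__main__':
+    import torch
+    from torchvision import models
+
+    model = models.resnet18(pretrained=True)
+    relus = find_layer_types_recursive(model, [torch.nn.ReLU])
+    print(len(relus), 'ReLU layers found')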
diff --git a/pytorch_grad_cam/utils/image.py b/pytorch_grad_cam/utils/image.py
new file mode 100644
index 0000000..34d92ba
--- /dev/null
+++ b/pytorch_grad_cam/utils/image.py
@@ -0,0 +1,183 @@
+import matplotlib
+from matplotlib import pyplot as plt
+from matplotlib.lines import Line2D
+import cv2
+import numpy as np
+import torch
+from torchvision.transforms import Compose, Normalize, ToTensor
+from typing import List, Dict
+import math
+
+
+def preprocess_image(img: np.ndarray,
+                     mean=[0.5, 0.5, 0.5],
+                     std=[0.5, 0.5, 0.5]) -> torch.Tensor:
+ preprocessing = Compose([
+ ToTensor(),
+ Normalize(mean=mean, std=std)
+ ])
+ return preprocessing(img.copy()).unsqueeze(0)
+
+
+def deprocess_image(img):
+ """ see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65 """
+ img = img - np.mean(img)
+ img = img / (np.std(img) + 1e-5)
+ img = img * 0.1
+ img = img + 0.5
+ img = np.clip(img, 0, 1)
+ return np.uint8(img * 255)
+
+
+def show_cam_on_image(img: np.ndarray,
+ mask: np.ndarray,
+ use_rgb: bool = False,
+ colormap: int = cv2.COLORMAP_JET,
+ image_weight: float = 0.5) -> np.ndarray:
+ """ This function overlays the cam mask on the image as an heatmap.
+ By default the heatmap is in BGR format.
+
+ :param img: The base image in RGB or BGR format.
+ :param mask: The cam mask.
+ :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
+ :param colormap: The OpenCV colormap to be used.
+ :param image_weight: The final result is image_weight * img + (1-image_weight) * mask.
+ :returns: The default image with the cam overlay.
+ """
+ heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
+ if use_rgb:
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+ heatmap = np.float32(heatmap) / 255
+
+ if np.max(img) > 1:
+        raise Exception(
+            "The input image should be np.float32 in the range [0, 1]")
+
+ if image_weight < 0 or image_weight > 1:
+ raise Exception(
+ f"image_weight should be in the range [0, 1].\
+ Got: {image_weight}")
+
+ cam = (1 - image_weight) * heatmap + image_weight * img
+ cam = cam / np.max(cam)
+ return np.uint8(255 * cam)
+
+
+def create_labels_legend(concept_scores: np.ndarray,
+ labels: Dict[int, str],
+ top_k=2):
+ concept_categories = np.argsort(concept_scores, axis=1)[:, ::-1][:, :top_k]
+ concept_labels_topk = []
+ for concept_index in range(concept_categories.shape[0]):
+ categories = concept_categories[concept_index, :]
+ concept_labels = []
+ for category in categories:
+ score = concept_scores[concept_index, category]
+ label = f"{','.join(labels[category].split(',')[:3])}:{score:.2f}"
+ concept_labels.append(label)
+ concept_labels_topk.append("\n".join(concept_labels))
+ return concept_labels_topk
+
+
+def show_factorization_on_image(img: np.ndarray,
+ explanations: np.ndarray,
+ colors: List[np.ndarray] = None,
+ image_weight: float = 0.5,
+ concept_labels: List = None) -> np.ndarray:
+ """ Color code the different component heatmaps on top of the image.
+ Every component color code will be magnified according to the heatmap itensity
+ (by modifying the V channel in the HSV color space),
+ and optionally create a lagend that shows the labels.
+
+ Since different factorization component heatmaps can overlap in principle,
+ we need a strategy to decide how to deal with the overlaps.
+ This keeps the component that has a higher value in it's heatmap.
+
+ :param img: The base image RGB format.
+ :param explanations: A tensor of shape num_componetns x height x width, with the component visualizations.
+ :param colors: List of R, G, B colors to be used for the components.
+ If None, will use the gist_rainbow cmap as a default.
+ :param image_weight: The final result is image_weight * img + (1-image_weight) * visualization.
+ :concept_labels: A list of strings for every component. If this is paseed, a legend that shows
+ the labels and their colors will be added to the image.
+ :returns: The visualized image.
+ """
+ n_components = explanations.shape[0]
+ if colors is None:
+ # taken from https://github.com/edocollins/DFF/blob/master/utils.py
+ _cmap = plt.cm.get_cmap('gist_rainbow')
+        colors = [np.array(_cmap(i))
+                  for i in np.arange(0, 1, 1.0 / n_components)]
+ concept_per_pixel = explanations.argmax(axis=0)
+ masks = []
+ for i in range(n_components):
+ mask = np.zeros(shape=(img.shape[0], img.shape[1], 3))
+ mask[:, :, :] = colors[i][:3]
+ explanation = explanations[i]
+ explanation[concept_per_pixel != i] = 0
+ mask = np.uint8(mask * 255)
+ mask = cv2.cvtColor(mask, cv2.COLOR_RGB2HSV)
+ mask[:, :, 2] = np.uint8(255 * explanation)
+ mask = cv2.cvtColor(mask, cv2.COLOR_HSV2RGB)
+ mask = np.float32(mask) / 255
+ masks.append(mask)
+
+ mask = np.sum(np.float32(masks), axis=0)
+ result = img * image_weight + mask * (1 - image_weight)
+ result = np.uint8(result * 255)
+
+ if concept_labels is not None:
+ px = 1 / plt.rcParams['figure.dpi'] # pixel in inches
+ fig = plt.figure(figsize=(result.shape[1] * px, result.shape[0] * px))
+ plt.rcParams['legend.fontsize'] = int(
+ 14 * result.shape[0] / 256 / max(1, n_components / 6))
+ lw = 5 * result.shape[0] / 256
+ lines = [Line2D([0], [0], color=colors[i], lw=lw)
+ for i in range(n_components)]
+ plt.legend(lines,
+ concept_labels,
+ mode="expand",
+ fancybox=True,
+ shadow=True)
+
+ plt.tight_layout(pad=0, w_pad=0, h_pad=0)
+ plt.axis('off')
+ fig.canvas.draw()
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ plt.close(fig=fig)
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+ data = cv2.resize(data, (result.shape[1], result.shape[0]))
+ result = np.hstack((result, data))
+ return result
+
+
+def scale_cam_image(cam, target_size=None):
+ result = []
+ for img in cam:
+ img = img - np.min(img)
+ img = img / (1e-7 + np.max(img))
+ if target_size is not None:
+ img = cv2.resize(img, target_size)
+ result.append(img)
+ result = np.float32(result)
+
+ return result
+
+
+def scale_accross_batch_and_channels(tensor, target_size):
+ batch_size, channel_size = tensor.shape[:2]
+ reshaped_tensor = tensor.reshape(
+ batch_size * channel_size, *tensor.shape[2:])
+ result = scale_cam_image(reshaped_tensor, target_size)
+ result = result.reshape(
+ batch_size,
+ channel_size,
+ target_size[1],
+ target_size[0])
+ return result
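+
+
+# Minimal end-to-end sketch (an illustrative addition): normalize an image,
+# rescale a fake low-resolution CAM, and overlay it. Random data stands in
+# for a real photo; only functions from this module are used.
+if __name__ == '__main__':
+    rgb_img = np.random.uniform(0, 1, size=(224, 224, 3)).astype(np.float32)
+    input_tensor = preprocess_image(rgb_img,
+                                    mean=[0.485, 0.456, 0.406],
+                                    std=[0.229, 0.224, 0.225])
+    grayscale_cam = scale_cam_image(np.random.rand(1, 7, 7), (224, 224))[0]
+    overlay = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
+    print(input_tensor.shape, overlay.shape, overlay.dtype)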
diff --git a/pytorch_grad_cam/utils/model_targets.py b/pytorch_grad_cam/utils/model_targets.py
new file mode 100644
index 0000000..489dd19
--- /dev/null
+++ b/pytorch_grad_cam/utils/model_targets.py
@@ -0,0 +1,103 @@
+import numpy as np
+import torch
+import torchvision
+
+
+class ClassifierOutputTarget:
+ def __init__(self, category):
+ self.category = category
+
+ def __call__(self, model_output):
+ if len(model_output.shape) == 1:
+ return model_output[self.category]
+ return model_output[:, self.category]
+
+
+class ClassifierOutputSoftmaxTarget:
+ def __init__(self, category):
+ self.category = category
+
+ def __call__(self, model_output):
+ if len(model_output.shape) == 1:
+ return torch.softmax(model_output, dim=-1)[self.category]
+ return torch.softmax(model_output, dim=-1)[:, self.category]
+
+
+class BinaryClassifierOutputTarget:
+ def __init__(self, category):
+ self.category = category
+
+ def __call__(self, model_output):
+ if self.category == 1:
+ sign = 1
+ else:
+ sign = -1
+ return model_output * sign
+
+
+class SoftmaxOutputTarget:
+ def __init__(self):
+ pass
+
+ def __call__(self, model_output):
+ return torch.softmax(model_output, dim=-1)
+
+
+class RawScoresOutputTarget:
+ def __init__(self):
+ pass
+
+ def __call__(self, model_output):
+ return model_output
+
+
+class SemanticSegmentationTarget:
+ """ Gets a binary spatial mask and a category,
+ And return the sum of the category scores,
+ of the pixels in the mask. """
+
+ def __init__(self, category, mask):
+ self.category = category
+ self.mask = torch.from_numpy(mask)
+ if torch.cuda.is_available():
+ self.mask = self.mask.cuda()
+
+ def __call__(self, model_output):
+ return (model_output[self.category, :, :] * self.mask).sum()
+
+
+class FasterRCNNBoxScoreTarget:
+ """ For every original detected bounding box specified in "bounding boxes",
+ assign a score on how the current bounding boxes match it,
+ 1. In IOU
+ 2. In the classification score.
+ If there is not a large enough overlap, or the category changed,
+ assign a score of 0.
+
+ The total score is the sum of all the box scores.
+ """
+
+ def __init__(self, labels, bounding_boxes, iou_threshold=0.5):
+ self.labels = labels
+ self.bounding_boxes = bounding_boxes
+ self.iou_threshold = iou_threshold
+
+ def __call__(self, model_outputs):
+ output = torch.Tensor([0])
+ if torch.cuda.is_available():
+ output = output.cuda()
+
+ if len(model_outputs["boxes"]) == 0:
+ return output
+
+ for box, label in zip(self.bounding_boxes, self.labels):
+ box = torch.Tensor(box[None, :])
+ if torch.cuda.is_available():
+ box = box.cuda()
+
+ ious = torchvision.ops.box_iou(box, model_outputs["boxes"])
+ index = ious.argmax()
+ if ious[0, index] > self.iou_threshold and model_outputs["labels"][index] == label:
+ score = ious[0, index] + model_outputs["scores"][index]
+ output = output + score
+ return output
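+
+
+# Minimal usage sketch (an illustrative addition): targets reduce a raw model
+# output to the scalar that is backpropagated. The logits are placeholders.
+if __name__ == '__main__':
+    logits = torch.randn(4, 1000)
+    target = ClassifierOutputTarget(281)  # arbitrary class index
+    print(target(logits).shape)  # one score per batch element
+    print(ClassifierOutputSoftmaxTarget(281)(logits).shape)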
diff --git a/pytorch_grad_cam/utils/reshape_transforms.py b/pytorch_grad_cam/utils/reshape_transforms.py
new file mode 100644
index 0000000..509f092
--- /dev/null
+++ b/pytorch_grad_cam/utils/reshape_transforms.py
@@ -0,0 +1,34 @@
+import torch
+
+
+def fasterrcnn_reshape_transform(x):
+ target_size = x['pool'].size()[-2:]
+ activations = []
+ for key, value in x.items():
+ activations.append(
+ torch.nn.functional.interpolate(
+ torch.abs(value),
+ target_size,
+ mode='bilinear'))
+ activations = torch.cat(activations, axis=1)
+ return activations
+
+
+def swinT_reshape_transform(tensor, height=7, width=7):
+ result = tensor.reshape(tensor.size(0),
+ height, width, tensor.size(2))
+
+ # Bring the channels to the first dimension,
+ # like in CNNs.
+ result = result.transpose(2, 3).transpose(1, 2)
+ return result
+
+
+def vit_reshape_transform(tensor, height=14, width=14):
+ result = tensor[:, 1:, :].reshape(tensor.size(0),
+ height, width, tensor.size(2))
+
+ # Bring the channels to the first dimension,
+ # like in CNNs.
+ result = result.transpose(2, 3).transpose(1, 2)
+ return result
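+
+
+# Minimal shape check (an illustrative addition): ViT token sequences of shape
+# (B, 1 + H*W, C) become CNN-style (B, C, H, W) activations; the class token
+# is dropped. The shapes below assume a ViT-B/16-like model.
+if __name__ == '__main__':
+    tokens = torch.randn(2, 1 + 14 * 14, 768)
+    print(vit_reshape_transform(tokens).shape)  # torch.Size([2, 768, 14, 14])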
diff --git a/pytorch_grad_cam/utils/svd_on_activations.py b/pytorch_grad_cam/utils/svd_on_activations.py
new file mode 100644
index 0000000..a406aee
--- /dev/null
+++ b/pytorch_grad_cam/utils/svd_on_activations.py
@@ -0,0 +1,19 @@
+import numpy as np
+
+
+def get_2d_projection(activation_batch):
+ # TBD: use pytorch batch svd implementation
+ activation_batch[np.isnan(activation_batch)] = 0
+ projections = []
+ for activations in activation_batch:
+ reshaped_activations = (activations).reshape(
+ activations.shape[0], -1).transpose()
+ # Centering before the SVD seems to be important here,
+ # Otherwise the image returned is negative
+ reshaped_activations = reshaped_activations - \
+ reshaped_activations.mean(axis=0)
+ U, S, VT = np.linalg.svd(reshaped_activations, full_matrices=True)
+ projection = reshaped_activations @ VT[0, :]
+ projection = projection.reshape(activations.shape[1:])
+ projections.append(projection)
+ return np.float32(projections)
diff --git a/pytorch_grad_cam/xgrad_cam.py b/pytorch_grad_cam/xgrad_cam.py
new file mode 100644
index 0000000..81a920f
--- /dev/null
+++ b/pytorch_grad_cam/xgrad_cam.py
@@ -0,0 +1,31 @@
+import numpy as np
+from pytorch_grad_cam.base_cam import BaseCAM
+
+
+class XGradCAM(BaseCAM):
+ def __init__(
+ self,
+ model,
+ target_layers,
+ use_cuda=False,
+ reshape_transform=None):
+ super(
+ XGradCAM,
+ self).__init__(
+ model,
+ target_layers,
+ use_cuda,
+ reshape_transform)
+
+ def get_cam_weights(self,
+ input_tensor,
+ target_layer,
+ target_category,
+ activations,
+ grads):
+ sum_activations = np.sum(activations, axis=(2, 3))
+ eps = 1e-7
+ weights = grads * activations / \
+ (sum_activations[:, :, None, None] + eps)
+ weights = weights.sum(axis=(2, 3))
+ return weights
diff --git a/restful_main.py b/restful_main.py
new file mode 100644
index 0000000..4d6fca2
--- /dev/null
+++ b/restful_main.py
@@ -0,0 +1,52 @@
+from flask import Flask, jsonify
+from flask_cors import CORS
+from flask_restful import Api, Resource, reqparse
+from utils import get_log
+from task.image_interpretability import ImageInterpretability
+import ast
+
+app = Flask(__name__, static_folder='', static_url_path='')
+app.config['UPLOAD_FOLDER'] = "./upload_data"
+cors = CORS(app, resources={r"*": {"origins": "*"}})
+api = Api(app)
+
+
+class Interpretability2Image(Resource):
+    # Assumed log location: get() references LOG_PATH, but the original code
+    # never defined it.
+    LOG_PATH = './aitest.log'
+
+ def post(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument('image_name', type=str, required=True, default='', help='')
+ parser.add_argument('image_path', type=str, required=True, default='', help='')
+        parser.add_argument('Interpretability_method', type=str, required=True, default='gradcam', help='')
+ parser.add_argument('model_info', type=str, required=True, default={}, help='')
+ parser.add_argument('output_path', type=str, required=True, default='', help='')
+ parser.add_argument('kwargs', type=str, required=True, default={}, help='')
+ args = parser.parse_args()
+ print(args)
+
+        args.model_info = ast.literal_eval(args.model_info)
+        args.kwargs = ast.literal_eval(args.kwargs)
+
+        Interpretability = ImageInterpretability()
+        rst = Interpretability.perform(
+            image_path=args.image_path,
+            method=args.Interpretability_method,
+            model_info=args.model_info,
+            output_path=args.output_path,
+            **args.kwargs
+        )
+ return jsonify(rst)
+
+ def get(self):
+ msg = get_log(log_path=Interpretability2Image.LOG_PATH)
+ if msg:
+ return jsonify({'status': 1, 'log': msg})
+ else:
+ return jsonify({'status': 0, 'log': None})
+
+
+api.add_resource(Interpretability2Image, '/Interpretability2Image')
+
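+# Minimal client sketch (an illustrative addition, not part of the service):
+# POST a task to the endpoint registered below. All field values are
+# placeholders, and `requests` is an assumed dependency.
+def example_client_request(base_url='http://127.0.0.1:5002'):
+    import requests  # assumed to be installed
+    resp = requests.post(base_url + '/Interpretability2Image', data={
+        'image_name': 'both.png',
+        'image_path': 'sample/both.png',
+        'Interpretability_method': 'gradcam',
+        'model_info': "{'model_name': 'resnet50', 'source': 'torchvision'}",
+        'output_path': './upload_data/',
+        'kwargs': "{'aug_smooth': True, 'eigen_smooth': True}",
+    })
+    return resp.json()
+
+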
+if __name__ == '__main__':
+ app.run(host='0.0.0.0', port=5002, debug=True)
\ No newline at end of file
diff --git a/sample/ILSVRC2012_val_00000001.JPEG b/sample/ILSVRC2012_val_00000001.JPEG
new file mode 100644
index 0000000..fd3a93f
Binary files /dev/null and b/sample/ILSVRC2012_val_00000001.JPEG differ
diff --git a/sample/ILSVRC2012_val_00000002.JPEG b/sample/ILSVRC2012_val_00000002.JPEG
new file mode 100644
index 0000000..543f639
Binary files /dev/null and b/sample/ILSVRC2012_val_00000002.JPEG differ
diff --git a/sample/ILSVRC2012_val_00000003.JPEG b/sample/ILSVRC2012_val_00000003.JPEG
new file mode 100644
index 0000000..b32c5ed
Binary files /dev/null and b/sample/ILSVRC2012_val_00000003.JPEG differ
diff --git a/sample/ILSVRC2012_val_00000004.JPEG b/sample/ILSVRC2012_val_00000004.JPEG
new file mode 100644
index 0000000..182189a
Binary files /dev/null and b/sample/ILSVRC2012_val_00000004.JPEG differ
diff --git a/sample/ILSVRC2012_val_00000005.JPEG b/sample/ILSVRC2012_val_00000005.JPEG
new file mode 100644
index 0000000..a68ef17
Binary files /dev/null and b/sample/ILSVRC2012_val_00000005.JPEG differ
diff --git a/sample/ILSVRC2012_val_00000006.JPEG b/sample/ILSVRC2012_val_00000006.JPEG
new file mode 100644
index 0000000..f284522
Binary files /dev/null and b/sample/ILSVRC2012_val_00000006.JPEG differ
diff --git a/sample/ILSVRC2012_val_00000007.JPEG b/sample/ILSVRC2012_val_00000007.JPEG
new file mode 100644
index 0000000..be25a8c
Binary files /dev/null and b/sample/ILSVRC2012_val_00000007.JPEG differ
diff --git a/sample/ILSVRC2012_val_00000008.JPEG b/sample/ILSVRC2012_val_00000008.JPEG
new file mode 100644
index 0000000..6b8ba97
Binary files /dev/null and b/sample/ILSVRC2012_val_00000008.JPEG differ
diff --git a/sample/ILSVRC2012_val_00000009.JPEG b/sample/ILSVRC2012_val_00000009.JPEG
new file mode 100644
index 0000000..9f6c9c0
Binary files /dev/null and b/sample/ILSVRC2012_val_00000009.JPEG differ
diff --git a/sample/both.png b/sample/both.png
new file mode 100644
index 0000000..61bdcae
Binary files /dev/null and b/sample/both.png differ
diff --git a/sample/horses.jpg b/sample/horses.jpg
new file mode 100644
index 0000000..a82d6a6
Binary files /dev/null and b/sample/horses.jpg differ
diff --git a/sample/img.png b/sample/img.png
new file mode 100644
index 0000000..28db411
Binary files /dev/null and b/sample/img.png differ
diff --git a/sample/img_1.png b/sample/img_1.png
new file mode 100644
index 0000000..6e3eeb0
Binary files /dev/null and b/sample/img_1.png differ
diff --git a/sample/img_2.png b/sample/img_2.png
new file mode 100644
index 0000000..f15a71b
Binary files /dev/null and b/sample/img_2.png differ
diff --git a/task/image_interpretability.py b/task/image_interpretability.py
new file mode 100644
index 0000000..4e9b297
--- /dev/null
+++ b/task/image_interpretability.py
@@ -0,0 +1,238 @@
+import argparse
+import cv2
+import numpy as np
+import torch
+import json
+from torchvision import models
+from importlib import import_module
+from utils import LogHelper
+from pytorch_grad_cam import GradCAM, \
+ HiResCAM, \
+ ScoreCAM, \
+ GradCAMPlusPlus, \
+ AblationCAM, \
+ XGradCAM, \
+ EigenCAM, \
+ EigenGradCAM, \
+ LayerCAM, \
+ FullGrad, \
+ GradCAMElementWise
+from pytorch_grad_cam.utils.find_layers import find_layer_types_recursive
+from pytorch_grad_cam import GuidedBackpropReLUModel
+from pytorch_grad_cam.utils.image import show_cam_on_image, \
+ deprocess_image, \
+ preprocess_image
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+from PIL import Image
+from torchvision import transforms
+import urllib.request
+
+
+# Load labels
+with open("imagenet_1000.json") as f:
+ labels = json.load(f)
+
+
+
+available_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--use-cuda', action='store_true', default=True,
+ help='Use NVIDIA GPU acceleration')
+ parser.add_argument('--aug_smooth', action='store_true',
+ help='Apply test time augmentation to smooth the CAM')
+ parser.add_argument(
+ '--eigen_smooth',
+ action='store_true',
+        help='Reduce noise by taking the first principal component '
+             'of cam_weights*activations')
+ args = parser.parse_args()
+ args.use_cuda = args.use_cuda and torch.cuda.is_available()
+ if args.use_cuda:
+ print('Using GPU for acceleration')
+ else:
+ print('Using CPU for computation')
+
+ return args
+
+
+class ImageInterpretability():
+ def __init__(self):
+ super().__init__()
+
+
+ def perform(self,image_path: str, method: str, model_info: dict,output_path: str,log_path=None,**kwargs):
+ # aug_smooth, eigen_smooth,
+ args = get_args()
+        '''
+        Input image location
+        :param image_path: (type=str, required=True) value= (path to the image)
+
+        Interpretability algorithm
+        :param method: (type=str, required=True) value=['gradcam', 'hirescam', 'scorecam', 'gradcam++',
+        'ablationcam', 'xgradcam', 'eigencam', 'eigengradcam', 'layercam', 'gradcamelementwise','fullgrad'] (interpretability method)
+
+        :param model_name: (type=str, required=True) value=[resnet, vgg, densenet, mnasnet] (model name)
+
+        :param model: (type=nn.Module, required=True) value=[resnet18, resnet50, vgg11, vgg13, vgg16, vgg19, densenet161, mnasnet1_0] (model)
+
+        kwargs: optional parameters, passed only under specific requirements
+        {
+            aug_smooth: apply test-time augmentation to improve CAM quality
+            (type=bool) value=[True, False]
+
+            eigen_smooth: compute the matrix product of the CAM weights and the activations,
+            then take the first principal component of the result to reduce noise.
+            (type=bool) value=[True, False]
+        }
+        '''
+ methods = \
+ {"gradcam": GradCAM,
+ "hirescam": HiResCAM,
+ "scorecam": ScoreCAM,
+ "gradcam++": GradCAMPlusPlus,
+ "ablationcam": AblationCAM,
+ "xgradcam": XGradCAM,
+ "eigencam": EigenCAM,
+ "eigengradcam": EigenGradCAM,
+ "layercam": LayerCAM,
+ "fullgrad": FullGrad,
+ "gradcamelementwise": GradCAMElementWise}
+
+ # model = models.resnet50(pretrained=True)
+ # model_class = getattr(models,model_info.model_name)
+ # model = model_class(pretrained=True)
+ model = self.get_model(info=model_info, device=available_device)
+ # attack_log = LogHelper(log_path=log_path, root_log_name='aitest').build_new_log()
+
+ if 'resnet' in model_info.get('model_name').lower():
+ target_layers = [model.layer4]
+ elif 'vgg' in model_info.get('model_name').lower():
+ target_layers = [model.features[-1]]
+ elif 'densenet' in model_info.get('model_name').lower():
+ target_layers = [model.features[-1]]
+ elif 'mnasnet' in model_info.get('model_name').lower():
+ target_layers = [model.layers[-1]]
+ else:
+ target_layers = find_layer_types_recursive(model, [torch.nn.ReLU])
+
+
+ rgb_img = cv2.imread(image_path, 1)[:, :, ::-1]
+ rgb_img = np.float32(rgb_img) / 255
+ input_tensor = preprocess_image(rgb_img,
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+
+
+ targets = None
+
+ cam_algorithm = methods[method]
+ with cam_algorithm(model=model,
+ target_layers=target_layers,
+ use_cuda=args.use_cuda) as cam:
+
+ # AblationCAM and ScoreCAM have batched implementations.
+ # You can override the internal batch size for faster computation.
+ cam.batch_size = 32
+            aug_smooth = kwargs.get('aug_smooth', False)
+            eigen_smooth = kwargs.get('eigen_smooth', False)
+ grayscale_cam = cam(input_tensor=input_tensor,
+ targets=targets,
+ aug_smooth=aug_smooth,
+ eigen_smooth=eigen_smooth)
+
+ # Here grayscale_cam has only one image in the batch
+ grayscale_cam = grayscale_cam[0, :]
+
+ cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
+
+ # cam_image is RGB encoded whereas "cv2.imwrite" requires BGR encoding.
+ cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)
+
+ gb_model = GuidedBackpropReLUModel(model=model, use_cuda=args.use_cuda)
+ gb = gb_model(input_tensor, target_category=None)
+
+ cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam])
+ cam_gb = deprocess_image(cam_mask * gb)
+ gb = deprocess_image(gb)
+
+ model.eval()
+ image = Image.open(image_path)
+ transform = transforms.Compose([
+ transforms.Resize(256),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]
+ )
+ ])
+ image = transform(image)
+ # Perform forward pass
+ with torch.no_grad():
+ output = model(image.unsqueeze(0))
+
+ # Convert output to probabilities
+ probabilities = torch.softmax(output[0], dim=0)
+            # Get predicted class index and its probability
+            predicted_class = torch.argmax(probabilities).item()
+            probability = probabilities[predicted_class].item()
+
+            # Get predicted label name
+ label_name = labels[str(predicted_class)]
+ #print(f"Predicted label: {label_name}")
+
+            # cv2.imwrite does not support paths that contain non-ASCII (e.g. Chinese) characters
+ cv2.imwrite(output_path+label_name+'-'+method+"_cam3.jpg", cam_image)
+ cv2.imwrite(output_path+label_name+'-'+method+"_gb.jpg", gb)
+ cv2.imwrite(output_path+label_name+'-'+method+"_cam_gb.jpg",cam_gb)
+
+ return {'output_path':output_path,'Predicted_class': label_name, 'Probability': probability}
+
+ @staticmethod
+ def get_model(info, device):
+ # if isinstance(info, dict):
+ # info = argparse.Namespace(**info)
+ #
+ # if isinstance(info.path, dict):
+ # if isinstance(info.path, str):
+ # info.path = json.loads(info.path)
+
+ # load pytorch model from user upload files
+ # if info.ownership != 'aitest' and ('upload' in info.type and 'pytorch' in info.type):
+ # return load_pytorch_model(state_dict_path=info.path['parameter_file'], device=device, net_file_path=info.path['structure_file'])
+
+        # load a built-in image model from torchvision
+ if 'torchvision' == info.get('source'):
+ model_class = getattr(models, info.get('model_name'))
+ return model_class(pretrained=True)
+
+
+
+def load_pytorch_model(state_dict_path, model_class_name, device, net_file_path=None):
+ """
+ model = load_user_model('word_cnn_for_classification',
+ 'C:\\Users\\pcl\\.cache\\textattack\\models\\classification\\cnn\\rotten-tomatoes\\pytorch_model.bin',
+ 'WordCNNForClassification')
+ """
+ if net_file_path:
+ net_file_path = str(net_file_path).replace('\\', '/')
+ net_file_path = net_file_path.replace('./', '').replace('/', '.')
+ model_class = getattr(import_module(net_file_path), model_class_name)
+ model = model_class()
+ model.load(state_dict_path, device)
+ # tokenizer = model.tokenizer
+ # model = textattack.models.wrappers.PyTorchModelWrapper(
+ # model, tokenizer
+ # )
+ return model
+ else:
+ return torch.load(state_dict_path, device)
diff --git a/test_image_interpretability.py b/test_image_interpretability.py
new file mode 100644
index 0000000..7a7ea7c
--- /dev/null
+++ b/test_image_interpretability.py
@@ -0,0 +1,27 @@
+from task.image_interpretability import ImageInterpretability
+
+
+# # TODO: run an interpretability analysis with the gradcam algorithm on a resnet model, using the local image both.png and default parameters
+# files_path = ImageInterpretability().perform(image_path='sample/both.png',method='gradcam', model_name='resnet',output_path='D:\桌面\image_interprebility\image_interprebility\image_interprebility/test')
+# print(files_path)
+#
+kwargs={
+ "aug_smooth":True,
+ "eigen_smooth":True
+ }
+model_info ={"model_name":'resnet50',
+ "source": "torchvision"
+ }
+# TODO: run an interpretability analysis with the fullgrad algorithm on a resnet model, using a local image and default parameters
+files_path = ImageInterpretability().perform(image_path='sample/img_2.png',method='fullgrad',model_info=model_info,output_path="D:/test_image/",**kwargs)
+print(files_path)
+
+
+# kwargs={
+# "target_layer":'',
+# "aug_smooth":True,
+# "eigen_smooth":True
+# }
+# # TODO: run an interpretability analysis with the fullgrad algorithm on a resnet model using a local image; apply test-time augmentation to improve CAM quality, reduce noise via principal-component extraction, and vary target_layer to get results from different layers
+# files_path = ImageInterpretability().perform(image_path='sample/ILSVRC2012_val_00000002.JPEG',method='gradcam', model_name='resnet',**kwargs)
+# print(files_path)
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..7029047
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,88 @@
+import logging
+import os
+import torch
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# aitest data
+MODEL_SAVE_PATH = 'data/aitest/models'
+ADV_EXAMPLES_SAVE_PATH = 'data/aitest/adv_examples'
+DATASET_SAVE_PATH = 'data/aitest/datasets'
+FILES_SAVE_PATH = 'data/aitest/files'
+
+
+# text_attack configuration
+USE_ENCODING_PATH = "data/aitest/models/tfhub_modules/063d866c06683311b44b4992fd46003be952409c"
+TEXTATTACK_CACHE_DIR = 'data/aitest/models/textattack'
+
+# text robust analysis
+GPT2_PATH = 'data/aitest/models/gpt-2'
+NLTK_PATH = 'data/aitest/files/nltk_data'
+
+
+class LogHelper:
+
+ LOG_STRING = "\033[34;1mAItest\033[0m"
+
+ def __init__(self, log_path, root_log_name=None):
+ self.log_path = log_path
+
+ if root_log_name:
+ self.log_name = f"{root_log_name}.Interpretability_Image"
+ self.logger = logging.getLogger(f"{root_log_name}.Interpretability_Image")
+ else:
+ self.log_name = 'aitest_Interpretability_Image'
+ self.logger = logging.getLogger('aitest_Interpretability_Image')
+ formatter = logging.Formatter(f"{LogHelper.LOG_STRING}: %(message)s")
+        stream_handler = logging.StreamHandler()
+ stream_handler.setFormatter(formatter)
+ self.logger.addHandler(stream_handler)
+
+ log_formatter = logging.Formatter("%(asctime)s : %(message)s")
+ log_filter = logging.Filter(name=self.log_name)
+ log_handler = logging.FileHandler(self.log_path)
+ log_handler.setLevel(level=logging.INFO)
+ log_handler.setFormatter(log_formatter)
+ log_handler.addFilter(log_filter)
+ self.fh = log_handler
+
+ def build_new_log(self):
+ self.logger.addHandler(self.fh)
+ return self
+
+ def insert_log(self, content=None):
+ if content:
+ self.logger.info(content)
+
+ def finish_log(self):
+ self.logger.removeHandler(self.fh)
+
+
+def get_log(log_path):
+ if os.path.exists(log_path):
+ with open(log_path, 'r+', encoding='utf-8') as f:
+ texts = f.read()
+ f.truncate(0)
+ return texts
+ return None
+
+
+def get_size(path, show_bytes=True):
+ size = 0
+ if os.path.isfile(path):
+ size = os.path.getsize(path)
+ elif os.path.isdir(path):
+
+ for root, dirs, files in os.walk(path):
+ for file in files:
+ size += os.path.getsize(os.path.join(root, file))
+ if show_bytes:
+ return int(size)
+
+ for x in ['bytes', 'KB', 'MB', 'GB', 'TB']:
+ if size < 1024.0:
+ return "%3.1f %s" % (size, x)
+ size /= 1024.0
+ return size
+
+