commit 478335a88525c4f6ed70d3fe070a6db4d4febcdf
Author: zhouminyang <784185055@qq.com>
Date:   Mon Jun 5 15:11:03 2023 +0800

    1.0

diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/image_interprebility.iml b/.idea/image_interprebility.iml
new file mode 100644
index 0000000..2b10bc0
--- /dev/null
+++ b/.idea/image_interprebility.iml
@@ -0,0 +1,12 @@
+<!-- 12 lines of module XML; markup not recoverable from this dump -->
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..1d8b614
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,18 @@
+<!-- 18 lines of inspection-profile XML; markup not recoverable from this dump -->
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+<!-- 6 lines of profile-settings XML; markup not recoverable from this dump -->
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..cbc002c
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+<!-- 4 lines of project XML; markup not recoverable from this dump -->
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..13620b5
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+<!-- 8 lines of modules XML; markup not recoverable from this dump -->
\ No newline at end of file
diff --git a/__pycache__/utils.cpython-38.pyc b/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000..361d507
Binary files /dev/null and b/__pycache__/utils.cpython-38.pyc differ
diff --git a/api.py b/api.py
new file mode 100644
index 0000000..5de005e
--- /dev/null
+++ b/api.py
@@ -0,0 +1,148 @@
+import argparse
+import cv2
+import numpy as np
+import torch
+from torchvision import models
+from pytorch_grad_cam import GradCAM, \
+    HiResCAM, \
+    ScoreCAM, \
+    GradCAMPlusPlus, \
+    AblationCAM, \
+    XGradCAM, \
+    EigenCAM, \
+    EigenGradCAM, \
+    LayerCAM, \
+    FullGrad, \
+    GradCAMElementWise
+
+from pytorch_grad_cam import GuidedBackpropReLUModel
+from pytorch_grad_cam.utils.image import show_cam_on_image, \
+    deprocess_image, \
+    preprocess_image
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--use-cuda', action='store_true', default=False,
+                        help='Use NVIDIA GPU acceleration')
+    parser.add_argument(
+        '--image-path',
+        type=str,
+        default='./examples/both.png',
+        help='Input image path')
+    parser.add_argument('--aug_smooth', action='store_true',
+                        help='Apply test time augmentation to smooth the CAM')
+    parser.add_argument(
+        '--eigen_smooth',
+        action='store_true',
+        help='Reduce noise by taking the first principal component '
+             'of cam_weights*activations')
+    parser.add_argument('--method', type=str, default='gradcam',
+                        choices=['gradcam', 'hirescam', 'gradcam++',
+                                 'scorecam', 'xgradcam',
+                                 'ablationcam', 'eigencam',
+                                 'eigengradcam', 'layercam', 'fullgrad'],
+                        help='Can be gradcam/gradcam++/scorecam/xgradcam'
+                             '/ablationcam/eigencam/eigengradcam/layercam')
+
+    args = parser.parse_args()
+    args.use_cuda = args.use_cuda and torch.cuda.is_available()
+    if args.use_cuda:
+        print('Using GPU for acceleration')
+    else:
+        print('Using CPU for computation')
+
+    return args
+
+
+def api(image_path, method, model_name, **kwargs):
+    args = get_args()
+    methods = \
+        {"gradcam": GradCAM,
+         "hirescam": HiResCAM,
+         "scorecam": ScoreCAM,
+         "gradcam++": GradCAMPlusPlus,
+         "ablationcam": AblationCAM,
+         "xgradcam": XGradCAM,
+         "eigencam": EigenCAM,
+         "eigengradcam": EigenGradCAM,
+         "layercam": LayerCAM,
+         "fullgrad": FullGrad,
+         "gradcamelementwise": GradCAMElementWise}
+
+    # Look up the requested torchvision constructor by name
+    # (e.g. 'resnet50') instead of building the call with eval().
+    model = getattr(models, model_name)(pretrained=True)
+    # print(model)  # uncomment to inspect the available layer names
+
+    # Choose the target layer you want to compute the visualization for.
+    # Usually this will be the last convolutional layer in the model.
+    # Some common choices can be:
+    # Resnet18 and 50: model.layer4
+    # VGG, densenet161: model.features[-1]
+    # mnasnet1_0: model.layers[-1]
+    # You can print the model to help choose the layer.
+    # You can pass a list with several target layers,
+    # in that case the CAMs will be computed per layer and then aggregated.
+    # You can also try selecting all layers of a certain type, with e.g.:
+    # from pytorch_grad_cam.utils.find_layers import find_layer_types_recursive
+    # find_layer_types_recursive(model, [torch.nn.ReLU])
+    target_layer = kwargs['target_layer']
+    target_layers = [getattr(model, target_layer)]
+
+    rgb_img = cv2.imread(image_path, 1)[:, :, ::-1]
+    rgb_img = np.float32(rgb_img) / 255
+    input_tensor = preprocess_image(rgb_img,
+                                    mean=[0.485, 0.456, 0.406],
+                                    std=[0.229, 0.224, 0.225])
+
+    # We have to specify the target we want to generate
+    # the Class Activation Maps for.
+    # If targets is None, the highest scoring category
+    # (for every member in the batch) will be used.
+    # You can target specific categories with e.g.
+    # targets = [ClassifierOutputTarget(281)]
+    targets = None
+
+    # Using the with statement ensures the context is freed, and you can
+    # recreate different CAM objects in a loop.
+    cam_algorithm = methods[method]
+    with cam_algorithm(model=model,
+                       target_layers=target_layers,
+                       use_cuda=args.use_cuda) as cam:
+
+        # AblationCAM and ScoreCAM have batched implementations.
+        # You can override the internal batch size for faster computation.
+        cam.batch_size = 32
+        grayscale_cam = cam(input_tensor=input_tensor,
+                            targets=targets,
+                            aug_smooth=kwargs['aug_smooth'],
+                            eigen_smooth=kwargs['eigen_smooth'])
+
+        # Here grayscale_cam has only one image in the batch
+        grayscale_cam = grayscale_cam[0, :]
+
+        cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
+
+    # cam_image is RGB encoded whereas "cv2.imwrite" requires BGR encoding.
+    cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)
+
+    gb_model = GuidedBackpropReLUModel(model=model, use_cuda=args.use_cuda)
+    gb = gb_model(input_tensor, target_category=None)
+
+    cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam])
+    cam_gb = deprocess_image(cam_mask * gb)
+    gb = deprocess_image(gb)
+
+    cv2.imwrite(f'{method}_cam.jpg', cam_image)
+    cv2.imwrite(f'{method}_gb.jpg', gb)
+    cv2.imwrite(f'{method}_cam_gb.jpg', cam_gb)
+    return method + '_gb.jpg'
+
+
+# Guard the demo so importing api.py does not run it.
+if __name__ == '__main__':
+    kwargs = {"target_layer": 'layer1',
+              "aug_smooth": True,
+              "eigen_smooth": True}
+    path = api('sample/both.png', 'fullgrad', 'resnet50', **kwargs)
+    print(path)
\ No newline at end of file
diff --git a/imagenet_1000.json b/imagenet_1000.json
new file mode 100644
index 0000000..0c06838
--- /dev/null
+++ b/imagenet_1000.json
@@ -0,0 +1 @@
+{"0": "tench, Tinca tinca", "1": "goldfish, Carassius auratus", "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", "3": "tiger shark, Galeocerdo cuvieri", "4": "hammerhead, hammerhead shark", "5": "electric ray, crampfish, numbfish, torpedo", "6": "stingray", "7": "cock", "8": "hen", "9": "ostrich, Struthio camelus", "10": "brambling, Fringilla montifringilla", "11": "goldfinch, Carduelis carduelis", "12": "house finch, linnet, Carpodacus mexicanus", "13": "junco, snowbird", "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea", "15": "robin, American robin, Turdus migratorius", "16": "bulbul", "17": "jay", "18": "magpie", "19": "chickadee", "20": "water ouzel, dipper", "21": "kite", "22": "bald eagle, American eagle, Haliaeetus leucocephalus", "23": "vulture", "24": "great grey owl, great gray owl, Strix nebulosa", "25": "European fire salamander, Salamandra salamandra", "26": "common newt, Triturus vulgaris", "27": "eft", "28": "spotted salamander, Ambystoma maculatum", "29": "axolotl, mud puppy, Ambystoma mexicanum", "30": "bullfrog, Rana catesbeiana", "31": "tree frog, tree-frog", "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", "33": "loggerhead, loggerhead turtle, Caretta caretta", "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", "35": "mud turtle", "36": "terrapin", "37": "box turtle, box tortoise", "38": "banded gecko", "39": "common iguana, iguana, Iguana iguana", "40": "American chameleon, anole, Anolis carolinensis", "41": "whiptail, whiptail lizard", "42": "agama", "43": "frilled lizard, Chlamydosaurus kingi", "44": "alligator lizard", "45": "Gila monster, Heloderma suspectum", "46": "green lizard, Lacerta viridis", "47": "African chameleon, Chamaeleo chamaeleon", "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", "49": "African crocodile, Nile crocodile, Crocodylus niloticus", "50": "American alligator, Alligator mississipiensis", "51": "triceratops", "52": "thunder snake, worm snake, Carphophis amoenus", "53": "ringneck snake, ring-necked snake, ring snake", "54": "hognose snake, puff adder, sand viper", "55": "green snake, grass snake", "56": "king snake, kingsnake", "57": "garter snake, grass snake", "58": "water snake", "59": "vine snake", "60": "night snake, Hypsiglena torquata", "61": "boa constrictor, Constrictor constrictor", "62": "rock python, rock snake, Python sebae", "63": "Indian cobra, Naja naja", "64": "green mamba", "65": "sea snake", "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus", "68": "sidewinder, horned rattlesnake,
Crotalus cerastes", "69": "trilobite", "70": "harvestman, daddy longlegs, Phalangium opilio", "71": "scorpion", "72": "black and gold garden spider, Argiope aurantia", "73": "barn spider, Araneus cavaticus", "74": "garden spider, Aranea diademata", "75": "black widow, Latrodectus mactans", "76": "tarantula", "77": "wolf spider, hunting spider", "78": "tick", "79": "centipede", "80": "black grouse", "81": "ptarmigan", "82": "ruffed grouse, partridge, Bonasa umbellus", "83": "prairie chicken, prairie grouse, prairie fowl", "84": "peacock", "85": "quail", "86": "partridge", "87": "African grey, African gray, Psittacus erithacus", "88": "macaw", "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", "90": "lorikeet", "91": "coucal", "92": "bee eater", "93": "hornbill", "94": "hummingbird", "95": "jacamar", "96": "toucan", "97": "drake", "98": "red-breasted merganser, Mergus serrator", "99": "goose", "100": "black swan, Cygnus atratus", "101": "tusker", "102": "echidna, spiny anteater, anteater", "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", "104": "wallaby, brush kangaroo", "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", "106": "wombat", "107": "jellyfish", "108": "sea anemone, anemone", "109": "brain coral", "110": "flatworm, platyhelminth", "111": "nematode, nematode worm, roundworm", "112": "conch", "113": "snail", "114": "slug", "115": "sea slug, nudibranch", "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore", "117": "chambered nautilus, pearly nautilus, nautilus", "118": "Dungeness crab, Cancer magister", "119": "rock crab, Cancer irroratus", "120": "fiddler crab", "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus", "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", "124": "crayfish, crawfish, crawdad, crawdaddy", "125": "hermit crab", "126": "isopod", "127": "white stork, Ciconia ciconia", "128": "black stork, Ciconia nigra", "129": "spoonbill", "130": "flamingo", "131": "little blue heron, Egretta caerulea", "132": "American egret, great white heron, Egretta albus", "133": "bittern", "134": "crane", "135": "limpkin, Aramus pictus", "136": "European gallinule, Porphyrio porphyrio", "137": "American coot, marsh hen, mud hen, water hen, Fulica americana", "138": "bustard", "139": "ruddy turnstone, Arenaria interpres", "140": "red-backed sandpiper, dunlin, Erolia alpina", "141": "redshank, Tringa totanus", "142": "dowitcher", "143": "oystercatcher, oyster catcher", "144": "pelican", "145": "king penguin, Aptenodytes patagonica", "146": "albatross, mollymawk", "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", "149": "dugong, Dugong dugon", "150": "sea lion", "151": "Chihuahua", "152": "Japanese spaniel", "153": "Maltese dog, Maltese terrier, Maltese", "154": "Pekinese, Pekingese, Peke", "155": "Shih-Tzu", "156": "Blenheim spaniel", "157": "papillon", "158": "toy terrier", "159": "Rhodesian ridgeback", "160": "Afghan hound, Afghan", "161": "basset, basset hound", "162": "beagle", "163": "bloodhound, sleuthhound", "164": "bluetick", "165": "black-and-tan coonhound", "166": "Walker hound, Walker foxhound", "167": "English foxhound", "168": "redbone", "169": "borzoi, Russian wolfhound", "170": "Irish wolfhound", "171": 
"Italian greyhound", "172": "whippet", "173": "Ibizan hound, Ibizan Podenco", "174": "Norwegian elkhound, elkhound", "175": "otterhound, otter hound", "176": "Saluki, gazelle hound", "177": "Scottish deerhound, deerhound", "178": "Weimaraner", "179": "Staffordshire bullterrier, Staffordshire bull terrier", "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", "181": "Bedlington terrier", "182": "Border terrier", "183": "Kerry blue terrier", "184": "Irish terrier", "185": "Norfolk terrier", "186": "Norwich terrier", "187": "Yorkshire terrier", "188": "wire-haired fox terrier", "189": "Lakeland terrier", "190": "Sealyham terrier, Sealyham", "191": "Airedale, Airedale terrier", "192": "cairn, cairn terrier", "193": "Australian terrier", "194": "Dandie Dinmont, Dandie Dinmont terrier", "195": "Boston bull, Boston terrier", "196": "miniature schnauzer", "197": "giant schnauzer", "198": "standard schnauzer", "199": "Scotch terrier, Scottish terrier, Scottie", "200": "Tibetan terrier, chrysanthemum dog", "201": "silky terrier, Sydney silky", "202": "soft-coated wheaten terrier", "203": "West Highland white terrier", "204": "Lhasa, Lhasa apso", "205": "flat-coated retriever", "206": "curly-coated retriever", "207": "golden retriever", "208": "Labrador retriever", "209": "Chesapeake Bay retriever", "210": "German short-haired pointer", "211": "vizsla, Hungarian pointer", "212": "English setter", "213": "Irish setter, red setter", "214": "Gordon setter", "215": "Brittany spaniel", "216": "clumber, clumber spaniel", "217": "English springer, English springer spaniel", "218": "Welsh springer spaniel", "219": "cocker spaniel, English cocker spaniel, cocker", "220": "Sussex spaniel", "221": "Irish water spaniel", "222": "kuvasz", "223": "schipperke", "224": "groenendael", "225": "malinois", "226": "briard", "227": "kelpie", "228": "komondor", "229": "Old English sheepdog, bobtail", "230": "Shetland sheepdog, Shetland sheep dog, Shetland", "231": "collie", "232": "Border collie", "233": "Bouvier des Flandres, Bouviers des Flandres", "234": "Rottweiler", "235": "German shepherd, German shepherd dog, German police dog, alsatian", "236": "Doberman, Doberman pinscher", "237": "miniature pinscher", "238": "Greater Swiss Mountain dog", "239": "Bernese mountain dog", "240": "Appenzeller", "241": "EntleBucher", "242": "boxer", "243": "bull mastiff", "244": "Tibetan mastiff", "245": "French bulldog", "246": "Great Dane", "247": "Saint Bernard, St Bernard", "248": "Eskimo dog, husky", "249": "malamute, malemute, Alaskan malamute", "250": "Siberian husky", "251": "dalmatian, coach dog, carriage dog", "252": "affenpinscher, monkey pinscher, monkey dog", "253": "basenji", "254": "pug, pug-dog", "255": "Leonberg", "256": "Newfoundland, Newfoundland dog", "257": "Great Pyrenees", "258": "Samoyed, Samoyede", "259": "Pomeranian", "260": "chow, chow chow", "261": "keeshond", "262": "Brabancon griffon", "263": "Pembroke, Pembroke Welsh corgi", "264": "Cardigan, Cardigan Welsh corgi", "265": "toy poodle", "266": "miniature poodle", "267": "standard poodle", "268": "Mexican hairless", "269": "timber wolf, grey wolf, gray wolf, Canis lupus", "270": "white wolf, Arctic wolf, Canis lupus tundrarum", "271": "red wolf, maned wolf, Canis rufus, Canis niger", "272": "coyote, prairie wolf, brush wolf, Canis latrans", "273": "dingo, warrigal, warragal, Canis dingo", "274": "dhole, Cuon alpinus", "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", "276": 
"hyena, hyaena", "277": "red fox, Vulpes vulpes", "278": "kit fox, Vulpes macrotis", "279": "Arctic fox, white fox, Alopex lagopus", "280": "grey fox, gray fox, Urocyon cinereoargenteus", "281": "tabby, tabby cat", "282": "tiger cat", "283": "Persian cat", "284": "Siamese cat, Siamese", "285": "Egyptian cat", "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", "287": "lynx, catamount", "288": "leopard, Panthera pardus", "289": "snow leopard, ounce, Panthera uncia", "290": "jaguar, panther, Panthera onca, Felis onca", "291": "lion, king of beasts, Panthera leo", "292": "tiger, Panthera tigris", "293": "cheetah, chetah, Acinonyx jubatus", "294": "brown bear, bruin, Ursus arctos", "295": "American black bear, black bear, Ursus americanus, Euarctos americanus", "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus", "297": "sloth bear, Melursus ursinus, Ursus ursinus", "298": "mongoose", "299": "meerkat, mierkat", "300": "tiger beetle", "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", "302": "ground beetle, carabid beetle", "303": "long-horned beetle, longicorn, longicorn beetle", "304": "leaf beetle, chrysomelid", "305": "dung beetle", "306": "rhinoceros beetle", "307": "weevil", "308": "fly", "309": "bee", "310": "ant, emmet, pismire", "311": "grasshopper, hopper", "312": "cricket", "313": "walking stick, walkingstick, stick insect", "314": "cockroach, roach", "315": "mantis, mantid", "316": "cicada, cicala", "317": "leafhopper", "318": "lacewing, lacewing fly", "319": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", "320": "damselfly", "321": "admiral", "322": "ringlet, ringlet butterfly", "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", "324": "cabbage butterfly", "325": "sulphur butterfly, sulfur butterfly", "326": "lycaenid, lycaenid butterfly", "327": "starfish, sea star", "328": "sea urchin", "329": "sea cucumber, holothurian", "330": "wood rabbit, cottontail, cottontail rabbit", "331": "hare", "332": "Angora, Angora rabbit", "333": "hamster", "334": "porcupine, hedgehog", "335": "fox squirrel, eastern fox squirrel, Sciurus niger", "336": "marmot", "337": "beaver", "338": "guinea pig, Cavia cobaya", "339": "sorrel", "340": "zebra", "341": "hog, pig, grunter, squealer, Sus scrofa", "342": "wild boar, boar, Sus scrofa", "343": "warthog", "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius", "345": "ox", "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", "347": "bison", "348": "ram, tup", "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", "350": "ibex, Capra ibex", "351": "hartebeest", "352": "impala, Aepyceros melampus", "353": "gazelle", "354": "Arabian camel, dromedary, Camelus dromedarius", "355": "llama", "356": "weasel", "357": "mink", "358": "polecat, fitch, foulmart, foumart, Mustela putorius", "359": "black-footed ferret, ferret, Mustela nigripes", "360": "otter", "361": "skunk, polecat, wood pussy", "362": "badger", "363": "armadillo", "364": "three-toed sloth, ai, Bradypus tridactylus", "365": "orangutan, orang, orangutang, Pongo pygmaeus", "366": "gorilla, Gorilla gorilla", "367": "chimpanzee, chimp, Pan troglodytes", "368": "gibbon, Hylobates lar", "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus", "370": "guenon, guenon monkey", "371": "patas, hussar monkey, Erythrocebus patas", "372": "baboon", "373": 
"macaque", "374": "langur", "375": "colobus, colobus monkey", "376": "proboscis monkey, Nasalis larvatus", "377": "marmoset", "378": "capuchin, ringtail, Cebus capucinus", "379": "howler monkey, howler", "380": "titi, titi monkey", "381": "spider monkey, Ateles geoffroyi", "382": "squirrel monkey, Saimiri sciureus", "383": "Madagascar cat, ring-tailed lemur, Lemur catta", "384": "indri, indris, Indri indri, Indri brevicaudatus", "385": "Indian elephant, Elephas maximus", "386": "African elephant, Loxodonta africana", "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens", "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", "389": "barracouta, snoek", "390": "eel", "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", "392": "rock beauty, Holocanthus tricolor", "393": "anemone fish", "394": "sturgeon", "395": "gar, garfish, garpike, billfish, Lepisosteus osseus", "396": "lionfish", "397": "puffer, pufferfish, blowfish, globefish", "398": "abacus", "399": "abaya", "400": "academic gown, academic robe, judge's robe", "401": "accordion, piano accordion, squeeze box", "402": "acoustic guitar", "403": "aircraft carrier, carrier, flattop, attack aircraft carrier", "404": "airliner", "405": "airship, dirigible", "406": "altar", "407": "ambulance", "408": "amphibian, amphibious vehicle", "409": "analog clock", "410": "apiary, bee house", "411": "apron", "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", "413": "assault rifle, assault gun", "414": "backpack, back pack, knapsack, packsack, rucksack, haversack", "415": "bakery, bakeshop, bakehouse", "416": "balance beam, beam", "417": "balloon", "418": "ballpoint, ballpoint pen, ballpen, Biro", "419": "Band Aid", "420": "banjo", "421": "bannister, banister, balustrade, balusters, handrail", "422": "barbell", "423": "barber chair", "424": "barbershop", "425": "barn", "426": "barometer", "427": "barrel, cask", "428": "barrow, garden cart, lawn cart, wheelbarrow", "429": "baseball", "430": "basketball", "431": "bassinet", "432": "bassoon", "433": "bathing cap, swimming cap", "434": "bath towel", "435": "bathtub, bathing tub, bath, tub", "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", "437": "beacon, lighthouse, beacon light, pharos", "438": "beaker", "439": "bearskin, busby, shako", "440": "beer bottle", "441": "beer glass", "442": "bell cote, bell cot", "443": "bib", "444": "bicycle-built-for-two, tandem bicycle, tandem", "445": "bikini, two-piece", "446": "binder, ring-binder", "447": "binoculars, field glasses, opera glasses", "448": "birdhouse", "449": "boathouse", "450": "bobsled, bobsleigh, bob", "451": "bolo tie, bolo, bola tie, bola", "452": "bonnet, poke bonnet", "453": "bookcase", "454": "bookshop, bookstore, bookstall", "455": "bottlecap", "456": "bow", "457": "bow tie, bow-tie, bowtie", "458": "brass, memorial tablet, plaque", "459": "brassiere, bra, bandeau", "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty", "461": "breastplate, aegis, egis", "462": "broom", "463": "bucket, pail", "464": "buckle", "465": "bulletproof vest", "466": "bullet train, bullet", "467": "butcher shop, meat market", "468": "cab, hack, taxi, taxicab", "469": "caldron, cauldron", "470": "candle, taper, wax light", "471": "cannon", "472": "canoe", "473": "can opener, tin opener", "474": "cardigan", "475": "car mirror", "476": "carousel, carrousel, merry-go-round, roundabout, 
whirligig", "477": "carpenter's kit, tool kit", "478": "carton", "479": "car wheel", "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", "481": "cassette", "482": "cassette player", "483": "castle", "484": "catamaran", "485": "CD player", "486": "cello, violoncello", "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone", "488": "chain", "489": "chainlink fence", "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", "491": "chain saw, chainsaw", "492": "chest", "493": "chiffonier, commode", "494": "chime, bell, gong", "495": "china cabinet, china closet", "496": "Christmas stocking", "497": "church, church building", "498": "cinema, movie theater, movie theatre, movie house, picture palace", "499": "cleaver, meat cleaver, chopper", "500": "cliff dwelling", "501": "cloak", "502": "clog, geta, patten, sabot", "503": "cocktail shaker", "504": "coffee mug", "505": "coffeepot", "506": "coil, spiral, volute, whorl, helix", "507": "combination lock", "508": "computer keyboard, keypad", "509": "confectionery, confectionary, candy store", "510": "container ship, containership, container vessel", "511": "convertible", "512": "corkscrew, bottle screw", "513": "cornet, horn, trumpet, trump", "514": "cowboy boot", "515": "cowboy hat, ten-gallon hat", "516": "cradle", "517": "crane", "518": "crash helmet", "519": "crate", "520": "crib, cot", "521": "Crock Pot", "522": "croquet ball", "523": "crutch", "524": "cuirass", "525": "dam, dike, dyke", "526": "desk", "527": "desktop computer", "528": "dial telephone, dial phone", "529": "diaper, nappy, napkin", "530": "digital clock", "531": "digital watch", "532": "dining table, board", "533": "dishrag, dishcloth", "534": "dishwasher, dish washer, dishwashing machine", "535": "disk brake, disc brake", "536": "dock, dockage, docking facility", "537": "dogsled, dog sled, dog sleigh", "538": "dome", "539": "doormat, welcome mat", "540": "drilling platform, offshore rig", "541": "drum, membranophone, tympan", "542": "drumstick", "543": "dumbbell", "544": "Dutch oven", "545": "electric fan, blower", "546": "electric guitar", "547": "electric locomotive", "548": "entertainment center", "549": "envelope", "550": "espresso maker", "551": "face powder", "552": "feather boa, boa", "553": "file, file cabinet, filing cabinet", "554": "fireboat", "555": "fire engine, fire truck", "556": "fire screen, fireguard", "557": "flagpole, flagstaff", "558": "flute, transverse flute", "559": "folding chair", "560": "football helmet", "561": "forklift", "562": "fountain", "563": "fountain pen", "564": "four-poster", "565": "freight car", "566": "French horn, horn", "567": "frying pan, frypan, skillet", "568": "fur coat", "569": "garbage truck, dustcart", "570": "gasmask, respirator, gas helmet", "571": "gas pump, gasoline pump, petrol pump, island dispenser", "572": "goblet", "573": "go-kart", "574": "golf ball", "575": "golfcart, golf cart", "576": "gondola", "577": "gong, tam-tam", "578": "gown", "579": "grand piano, grand", "580": "greenhouse, nursery, glasshouse", "581": "grille, radiator grille", "582": "grocery store, grocery, food market, market", "583": "guillotine", "584": "hair slide", "585": "hair spray", "586": "half track", "587": "hammer", "588": "hamper", "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier", "590": "hand-held computer, hand-held microcomputer", "591": "handkerchief, hankie, hanky, hankey", "592": "hard 
disc, hard disk, fixed disk", "593": "harmonica, mouth organ, harp, mouth harp", "594": "harp", "595": "harvester, reaper", "596": "hatchet", "597": "holster", "598": "home theater, home theatre", "599": "honeycomb", "600": "hook, claw", "601": "hoopskirt, crinoline", "602": "horizontal bar, high bar", "603": "horse cart, horse-cart", "604": "hourglass", "605": "iPod", "606": "iron, smoothing iron", "607": "jack-o'-lantern", "608": "jean, blue jean, denim", "609": "jeep, landrover", "610": "jersey, T-shirt, tee shirt", "611": "jigsaw puzzle", "612": "jinrikisha, ricksha, rickshaw", "613": "joystick", "614": "kimono", "615": "knee pad", "616": "knot", "617": "lab coat, laboratory coat", "618": "ladle", "619": "lampshade, lamp shade", "620": "laptop, laptop computer", "621": "lawn mower, mower", "622": "lens cap, lens cover", "623": "letter opener, paper knife, paperknife", "624": "library", "625": "lifeboat", "626": "lighter, light, igniter, ignitor", "627": "limousine, limo", "628": "liner, ocean liner", "629": "lipstick, lip rouge", "630": "Loafer", "631": "lotion", "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", "633": "loupe, jeweler's loupe", "634": "lumbermill, sawmill", "635": "magnetic compass", "636": "mailbag, postbag", "637": "mailbox, letter box", "638": "maillot", "639": "maillot, tank suit", "640": "manhole cover", "641": "maraca", "642": "marimba, xylophone", "643": "mask", "644": "matchstick", "645": "maypole", "646": "maze, labyrinth", "647": "measuring cup", "648": "medicine chest, medicine cabinet", "649": "megalith, megalithic structure", "650": "microphone, mike", "651": "microwave, microwave oven", "652": "military uniform", "653": "milk can", "654": "minibus", "655": "miniskirt, mini", "656": "minivan", "657": "missile", "658": "mitten", "659": "mixing bowl", "660": "mobile home, manufactured home", "661": "Model T", "662": "modem", "663": "monastery", "664": "monitor", "665": "moped", "666": "mortar", "667": "mortarboard", "668": "mosque", "669": "mosquito net", "670": "motor scooter, scooter", "671": "mountain bike, all-terrain bike, off-roader", "672": "mountain tent", "673": "mouse, computer mouse", "674": "mousetrap", "675": "moving van", "676": "muzzle", "677": "nail", "678": "neck brace", "679": "necklace", "680": "nipple", "681": "notebook, notebook computer", "682": "obelisk", "683": "oboe, hautboy, hautbois", "684": "ocarina, sweet potato", "685": "odometer, hodometer, mileometer, milometer", "686": "oil filter", "687": "organ, pipe organ", "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO", "689": "overskirt", "690": "oxcart", "691": "oxygen mask", "692": "packet", "693": "paddle, boat paddle", "694": "paddlewheel, paddle wheel", "695": "padlock", "696": "paintbrush", "697": "pajama, pyjama, pj's, jammies", "698": "palace", "699": "panpipe, pandean pipe, syrinx", "700": "paper towel", "701": "parachute, chute", "702": "parallel bars, bars", "703": "park bench", "704": "parking meter", "705": "passenger car, coach, carriage", "706": "patio, terrace", "707": "pay-phone, pay-station", "708": "pedestal, plinth, footstall", "709": "pencil box, pencil case", "710": "pencil sharpener", "711": "perfume, essence", "712": "Petri dish", "713": "photocopier", "714": "pick, plectrum, plectron", "715": "pickelhaube", "716": "picket fence, paling", "717": "pickup, pickup truck", "718": "pier", "719": "piggy bank, penny bank", "720": "pill bottle", "721": "pillow", "722": "ping-pong ball", "723": "pinwheel", "724": "pirate, pirate 
ship", "725": "pitcher, ewer", "726": "plane, carpenter's plane, woodworking plane", "727": "planetarium", "728": "plastic bag", "729": "plate rack", "730": "plow, plough", "731": "plunger, plumber's helper", "732": "Polaroid camera, Polaroid Land camera", "733": "pole", "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", "735": "poncho", "736": "pool table, billiard table, snooker table", "737": "pop bottle, soda bottle", "738": "pot, flowerpot", "739": "potter's wheel", "740": "power drill", "741": "prayer rug, prayer mat", "742": "printer", "743": "prison, prison house", "744": "projectile, missile", "745": "projector", "746": "puck, hockey puck", "747": "punching bag, punch bag, punching ball, punchball", "748": "purse", "749": "quill, quill pen", "750": "quilt, comforter, comfort, puff", "751": "racer, race car, racing car", "752": "racket, racquet", "753": "radiator", "754": "radio, wireless", "755": "radio telescope, radio reflector", "756": "rain barrel", "757": "recreational vehicle, RV, R.V.", "758": "reel", "759": "reflex camera", "760": "refrigerator, icebox", "761": "remote control, remote", "762": "restaurant, eating house, eating place, eatery", "763": "revolver, six-gun, six-shooter", "764": "rifle", "765": "rocking chair, rocker", "766": "rotisserie", "767": "rubber eraser, rubber, pencil eraser", "768": "rugby ball", "769": "rule, ruler", "770": "running shoe", "771": "safe", "772": "safety pin", "773": "saltshaker, salt shaker", "774": "sandal", "775": "sarong", "776": "sax, saxophone", "777": "scabbard", "778": "scale, weighing machine", "779": "school bus", "780": "schooner", "781": "scoreboard", "782": "screen, CRT screen", "783": "screw", "784": "screwdriver", "785": "seat belt, seatbelt", "786": "sewing machine", "787": "shield, buckler", "788": "shoe shop, shoe-shop, shoe store", "789": "shoji", "790": "shopping basket", "791": "shopping cart", "792": "shovel", "793": "shower cap", "794": "shower curtain", "795": "ski", "796": "ski mask", "797": "sleeping bag", "798": "slide rule, slipstick", "799": "sliding door", "800": "slot, one-armed bandit", "801": "snorkel", "802": "snowmobile", "803": "snowplow, snowplough", "804": "soap dispenser", "805": "soccer ball", "806": "sock", "807": "solar dish, solar collector, solar furnace", "808": "sombrero", "809": "soup bowl", "810": "space bar", "811": "space heater", "812": "space shuttle", "813": "spatula", "814": "speedboat", "815": "spider web, spider's web", "816": "spindle", "817": "sports car, sport car", "818": "spotlight, spot", "819": "stage", "820": "steam locomotive", "821": "steel arch bridge", "822": "steel drum", "823": "stethoscope", "824": "stole", "825": "stone wall", "826": "stopwatch, stop watch", "827": "stove", "828": "strainer", "829": "streetcar, tram, tramcar, trolley, trolley car", "830": "stretcher", "831": "studio couch, day bed", "832": "stupa, tope", "833": "submarine, pigboat, sub, U-boat", "834": "suit, suit of clothes", "835": "sundial", "836": "sunglass", "837": "sunglasses, dark glasses, shades", "838": "sunscreen, sunblock, sun blocker", "839": "suspension bridge", "840": "swab, swob, mop", "841": "sweatshirt", "842": "swimming trunks, bathing trunks", "843": "swing", "844": "switch, electric switch, electrical switch", "845": "syringe", "846": "table lamp", "847": "tank, army tank, armored combat vehicle, armoured combat vehicle", "848": "tape player", "849": "teapot", "850": "teddy, teddy bear", "851": "television, television system", "852": "tennis ball", 
"853": "thatch, thatched roof", "854": "theater curtain, theatre curtain", "855": "thimble", "856": "thresher, thrasher, threshing machine", "857": "throne", "858": "tile roof", "859": "toaster", "860": "tobacco shop, tobacconist shop, tobacconist", "861": "toilet seat", "862": "torch", "863": "totem pole", "864": "tow truck, tow car, wrecker", "865": "toyshop", "866": "tractor", "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", "868": "tray", "869": "trench coat", "870": "tricycle, trike, velocipede", "871": "trimaran", "872": "tripod", "873": "triumphal arch", "874": "trolleybus, trolley coach, trackless trolley", "875": "trombone", "876": "tub, vat", "877": "turnstile", "878": "typewriter keyboard", "879": "umbrella", "880": "unicycle, monocycle", "881": "upright, upright piano", "882": "vacuum, vacuum cleaner", "883": "vase", "884": "vault", "885": "velvet", "886": "vending machine", "887": "vestment", "888": "viaduct", "889": "violin, fiddle", "890": "volleyball", "891": "waffle iron", "892": "wall clock", "893": "wallet, billfold, notecase, pocketbook", "894": "wardrobe, closet, press", "895": "warplane, military plane", "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin", "897": "washer, automatic washer, washing machine", "898": "water bottle", "899": "water jug", "900": "water tower", "901": "whiskey jug", "902": "whistle", "903": "wig", "904": "window screen", "905": "window shade", "906": "Windsor tie", "907": "wine bottle", "908": "wing", "909": "wok", "910": "wooden spoon", "911": "wool, woolen, woollen", "912": "worm fence, snake fence, snake-rail fence, Virginia fence", "913": "wreck", "914": "yawl", "915": "yurt", "916": "web site, website, internet site, site", "917": "comic book", "918": "crossword puzzle, crossword", "919": "street sign", "920": "traffic light, traffic signal, stoplight", "921": "book jacket, dust cover, dust jacket, dust wrapper", "922": "menu", "923": "plate", "924": "guacamole", "925": "consomme", "926": "hot pot, hotpot", "927": "trifle", "928": "ice cream, icecream", "929": "ice lolly, lolly, lollipop, popsicle", "930": "French loaf", "931": "bagel, beigel", "932": "pretzel", "933": "cheeseburger", "934": "hotdog, hot dog, red hot", "935": "mashed potato", "936": "head cabbage", "937": "broccoli", "938": "cauliflower", "939": "zucchini, courgette", "940": "spaghetti squash", "941": "acorn squash", "942": "butternut squash", "943": "cucumber, cuke", "944": "artichoke, globe artichoke", "945": "bell pepper", "946": "cardoon", "947": "mushroom", "948": "Granny Smith", "949": "strawberry", "950": "orange", "951": "lemon", "952": "fig", "953": "pineapple, ananas", "954": "banana", "955": "jackfruit, jak, jack", "956": "custard apple", "957": "pomegranate", "958": "hay", "959": "carbonara", "960": "chocolate sauce, chocolate syrup", "961": "dough", "962": "meat loaf, meatloaf", "963": "pizza, pizza pie", "964": "potpie", "965": "burrito", "966": "red wine", "967": "espresso", "968": "cup", "969": "eggnog", "970": "alp", "971": "bubble", "972": "cliff, drop, drop-off", "973": "coral reef", "974": "geyser", "975": "lakeside, lakeshore", "976": "promontory, headland, head, foreland", "977": "sandbar, sand bar", "978": "seashore, coast, seacoast, sea-coast", "979": "valley, vale", "980": "volcano", "981": "ballplayer, baseball player", "982": "groom, bridegroom", "983": "scuba diver", "984": "rapeseed", "985": "daisy", "986": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium 
parviflorum", "987": "corn", "988": "acorn", "989": "hip, rose hip, rosehip", "990": "buckeye, horse chestnut, conker", "991": "coral fungus", "992": "agaric", "993": "gyromitra", "994": "stinkhorn, carrion fungus", "995": "earthstar", "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", "997": "bolete", "998": "ear, spike, capitulum", "999": "toilet tissue, toilet paper, bathroom tissue"} \ No newline at end of file diff --git a/imagenet_1000.txt b/imagenet_1000.txt new file mode 100644 index 0000000..376e180 --- /dev/null +++ b/imagenet_1000.txt @@ -0,0 +1,1000 @@ +0 tench, Tinca tinca +1 goldfish, Carassius auratus +2 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias +3 tiger shark, Galeocerdo cuvieri +4 hammerhead, hammerhead shark +5 electric ray, crampfish, numbfish, torpedo +6 stingray +7 cock +8 hen +9 ostrich, Struthio camelus +10 brambling, Fringilla montifringilla +11 goldfinch, Carduelis carduelis +12 house finch, linnet, Carpodacus mexicanus +13 junco, snowbird +14 indigo bunting, indigo finch, indigo bird, Passerina cyanea +15 robin, American robin, Turdus migratorius +16 bulbul +17 jay +18 magpie +19 chickadee +20 water ouzel, dipper +21 kite +22 bald eagle, American eagle, Haliaeetus leucocephalus +23 vulture +24 great grey owl, great gray owl, Strix nebulosa +25 European fire salamander, Salamandra salamandra +26 common newt, Triturus vulgaris +27 eft +28 spotted salamander, Ambystoma maculatum +29 axolotl, mud puppy, Ambystoma mexicanum +30 bullfrog, Rana catesbeiana +31 tree frog, tree-frog +32 tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui +33 loggerhead, loggerhead turtle, Caretta caretta +34 leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea +35 mud turtle +36 terrapin +37 box turtle, box tortoise +38 banded gecko +39 common iguana, iguana, Iguana iguana +40 American chameleon, anole, Anolis carolinensis +41 whiptail, whiptail lizard +42 agama +43 frilled lizard, Chlamydosaurus kingi +44 alligator lizard +45 Gila monster, Heloderma suspectum +46 green lizard, Lacerta viridis +47 African chameleon, Chamaeleo chamaeleon +48 Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis +49 African crocodile, Nile crocodile, Crocodylus niloticus +50 American alligator, Alligator mississipiensis +51 triceratops +52 thunder snake, worm snake, Carphophis amoenus +53 ringneck snake, ring-necked snake, ring snake +54 hognose snake, puff adder, sand viper +55 green snake, grass snake +56 king snake, kingsnake +57 garter snake, grass snake +58 water snake +59 vine snake +60 night snake, Hypsiglena torquata +61 boa constrictor, Constrictor constrictor +62 rock python, rock snake, Python sebae +63 Indian cobra, Naja naja +64 green mamba +65 sea snake +66 horned viper, cerastes, sand viper, horned asp, Cerastes cornutus +67 diamondback, diamondback rattlesnake, Crotalus adamanteus +68 sidewinder, horned rattlesnake, Crotalus cerastes +69 trilobite +70 harvestman, daddy longlegs, Phalangium opilio +71 scorpion +72 black and gold garden spider, Argiope aurantia +73 barn spider, Araneus cavaticus +74 garden spider, Aranea diademata +75 black widow, Latrodectus mactans +76 tarantula +77 wolf spider, hunting spider +78 tick +79 centipede +80 black grouse +81 ptarmigan +82 ruffed grouse, partridge, Bonasa umbellus +83 prairie chicken, prairie grouse, prairie fowl +84 peacock +85 quail +86 partridge +87 African grey, African gray, Psittacus erithacus +88 macaw +89 
sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita +90 lorikeet +91 coucal +92 bee eater +93 hornbill +94 hummingbird +95 jacamar +96 toucan +97 drake +98 red-breasted merganser, Mergus serrator +99 goose +100 black swan, Cygnus atratus +101 tusker +102 echidna, spiny anteater, anteater +103 platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus +104 wallaby, brush kangaroo +105 koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus +106 wombat +107 jellyfish +108 sea anemone, anemone +109 brain coral +110 flatworm, platyhelminth +111 nematode, nematode worm, roundworm +112 conch +113 snail +114 slug +115 sea slug, nudibranch +116 chiton, coat-of-mail shell, sea cradle, polyplacophore +117 chambered nautilus, pearly nautilus, nautilus +118 Dungeness crab, Cancer magister +119 rock crab, Cancer irroratus +120 fiddler crab +121 king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica +122 American lobster, Northern lobster, Maine lobster, Homarus americanus +123 spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish +124 crayfish, crawfish, crawdad, crawdaddy +125 hermit crab +126 isopod +127 white stork, Ciconia ciconia +128 black stork, Ciconia nigra +129 spoonbill +130 flamingo +131 little blue heron, Egretta caerulea +132 American egret, great white heron, Egretta albus +133 bittern +134 crane +135 limpkin, Aramus pictus +136 European gallinule, Porphyrio porphyrio +137 American coot, marsh hen, mud hen, water hen, Fulica americana +138 bustard +139 ruddy turnstone, Arenaria interpres +140 red-backed sandpiper, dunlin, Erolia alpina +141 redshank, Tringa totanus +142 dowitcher +143 oystercatcher, oyster catcher +144 pelican +145 king penguin, Aptenodytes patagonica +146 albatross, mollymawk +147 grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus +148 killer whale, killer, orca, grampus, sea wolf, Orcinus orca +149 dugong, Dugong dugon +150 sea lion +151 Chihuahua +152 Japanese spaniel +153 Maltese dog, Maltese terrier, Maltese +154 Pekinese, Pekingese, Peke +155 Shih-Tzu +156 Blenheim spaniel +157 papillon +158 toy terrier +159 Rhodesian ridgeback +160 Afghan hound, Afghan +161 basset, basset hound +162 beagle +163 bloodhound, sleuthhound +164 bluetick +165 black-and-tan coonhound +166 Walker hound, Walker foxhound +167 English foxhound +168 redbone +169 borzoi, Russian wolfhound +170 Irish wolfhound +171 Italian greyhound +172 whippet +173 Ibizan hound, Ibizan Podenco +174 Norwegian elkhound, elkhound +175 otterhound, otter hound +176 Saluki, gazelle hound +177 Scottish deerhound, deerhound +178 Weimaraner +179 Staffordshire bullterrier, Staffordshire bull terrier +180 American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier +181 Bedlington terrier +182 Border terrier +183 Kerry blue terrier +184 Irish terrier +185 Norfolk terrier +186 Norwich terrier +187 Yorkshire terrier +188 wire-haired fox terrier +189 Lakeland terrier +190 Sealyham terrier, Sealyham +191 Airedale, Airedale terrier +192 cairn, cairn terrier +193 Australian terrier +194 Dandie Dinmont, Dandie Dinmont terrier +195 Boston bull, Boston terrier +196 miniature schnauzer +197 giant schnauzer +198 standard schnauzer +199 Scotch terrier, Scottish terrier, Scottie +200 Tibetan terrier, chrysanthemum dog +201 silky terrier, Sydney silky +202 soft-coated wheaten terrier +203 West Highland white terrier +204 Lhasa, Lhasa apso +205 flat-coated retriever 
+206 curly-coated retriever +207 golden retriever +208 Labrador retriever +209 Chesapeake Bay retriever +210 German short-haired pointer +211 vizsla, Hungarian pointer +212 English setter +213 Irish setter, red setter +214 Gordon setter +215 Brittany spaniel +216 clumber, clumber spaniel +217 English springer, English springer spaniel +218 Welsh springer spaniel +219 cocker spaniel, English cocker spaniel, cocker +220 Sussex spaniel +221 Irish water spaniel +222 kuvasz +223 schipperke +224 groenendael +225 malinois +226 briard +227 kelpie +228 komondor +229 Old English sheepdog, bobtail +230 Shetland sheepdog, Shetland sheep dog, Shetland +231 collie +232 Border collie +233 Bouvier des Flandres, Bouviers des Flandres +234 Rottweiler +235 German shepherd, German shepherd dog, German police dog, alsatian +236 Doberman, Doberman pinscher +237 miniature pinscher +238 Greater Swiss Mountain dog +239 Bernese mountain dog +240 Appenzeller +241 EntleBucher +242 boxer +243 bull mastiff +244 Tibetan mastiff +245 French bulldog +246 Great Dane +247 Saint Bernard, St Bernard +248 Eskimo dog, husky +249 malamute, malemute, Alaskan malamute +250 Siberian husky +251 dalmatian, coach dog, carriage dog +252 affenpinscher, monkey pinscher, monkey dog +253 basenji +254 pug, pug-dog +255 Leonberg +256 Newfoundland, Newfoundland dog +257 Great Pyrenees +258 Samoyed, Samoyede +259 Pomeranian +260 chow, chow chow +261 keeshond +262 Brabancon griffon +263 Pembroke, Pembroke Welsh corgi +264 Cardigan, Cardigan Welsh corgi +265 toy poodle +266 miniature poodle +267 standard poodle +268 Mexican hairless +269 timber wolf, grey wolf, gray wolf, Canis lupus +270 white wolf, Arctic wolf, Canis lupus tundrarum +271 red wolf, maned wolf, Canis rufus, Canis niger +272 coyote, prairie wolf, brush wolf, Canis latrans +273 dingo, warrigal, warragal, Canis dingo +274 dhole, Cuon alpinus +275 African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus +276 hyena, hyaena +277 red fox, Vulpes vulpes +278 kit fox, Vulpes macrotis +279 Arctic fox, white fox, Alopex lagopus +280 grey fox, gray fox, Urocyon cinereoargenteus +281 tabby, tabby cat +282 tiger cat +283 Persian cat +284 Siamese cat, Siamese +285 Egyptian cat +286 cougar, puma, catamount, mountain lion, painter, panther, Felis concolor +287 lynx, catamount +288 leopard, Panthera pardus +289 snow leopard, ounce, Panthera uncia +290 jaguar, panther, Panthera onca, Felis onca +291 lion, king of beasts, Panthera leo +292 tiger, Panthera tigris +293 cheetah, chetah, Acinonyx jubatus +294 brown bear, bruin, Ursus arctos +295 American black bear, black bear, Ursus americanus, Euarctos americanus +296 ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus +297 sloth bear, Melursus ursinus, Ursus ursinus +298 mongoose +299 meerkat, mierkat +300 tiger beetle +301 ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle +302 ground beetle, carabid beetle +303 long-horned beetle, longicorn, longicorn beetle +304 leaf beetle, chrysomelid +305 dung beetle +306 rhinoceros beetle +307 weevil +308 fly +309 bee +310 ant, emmet, pismire +311 grasshopper, hopper +312 cricket +313 walking stick, walkingstick, stick insect +314 cockroach, roach +315 mantis, mantid +316 cicada, cicala +317 leafhopper +318 lacewing, lacewing fly +319 dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk +320 damselfly +321 admiral +322 ringlet, ringlet butterfly +323 monarch, monarch butterfly, milkweed butterfly, Danaus plexippus 
+324 cabbage butterfly +325 sulphur butterfly, sulfur butterfly +326 lycaenid, lycaenid butterfly +327 starfish, sea star +328 sea urchin +329 sea cucumber, holothurian +330 wood rabbit, cottontail, cottontail rabbit +331 hare +332 Angora, Angora rabbit +333 hamster +334 porcupine, hedgehog +335 fox squirrel, eastern fox squirrel, Sciurus niger +336 marmot +337 beaver +338 guinea pig, Cavia cobaya +339 sorrel +340 zebra +341 hog, pig, grunter, squealer, Sus scrofa +342 wild boar, boar, Sus scrofa +343 warthog +344 hippopotamus, hippo, river horse, Hippopotamus amphibius +345 ox +346 water buffalo, water ox, Asiatic buffalo, Bubalus bubalis +347 bison +348 ram, tup +349 bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis +350 ibex, Capra ibex +351 hartebeest +352 impala, Aepyceros melampus +353 gazelle +354 Arabian camel, dromedary, Camelus dromedarius +355 llama +356 weasel +357 mink +358 polecat, fitch, foulmart, foumart, Mustela putorius +359 black-footed ferret, ferret, Mustela nigripes +360 otter +361 skunk, polecat, wood pussy +362 badger +363 armadillo +364 three-toed sloth, ai, Bradypus tridactylus +365 orangutan, orang, orangutang, Pongo pygmaeus +366 gorilla, Gorilla gorilla +367 chimpanzee, chimp, Pan troglodytes +368 gibbon, Hylobates lar +369 siamang, Hylobates syndactylus, Symphalangus syndactylus +370 guenon, guenon monkey +371 patas, hussar monkey, Erythrocebus patas +372 baboon +373 macaque +374 langur +375 colobus, colobus monkey +376 proboscis monkey, Nasalis larvatus +377 marmoset +378 capuchin, ringtail, Cebus capucinus +379 howler monkey, howler +380 titi, titi monkey +381 spider monkey, Ateles geoffroyi +382 squirrel monkey, Saimiri sciureus +383 Madagascar cat, ring-tailed lemur, Lemur catta +384 indri, indris, Indri indri, Indri brevicaudatus +385 Indian elephant, Elephas maximus +386 African elephant, Loxodonta africana +387 lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens +388 giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca +389 barracouta, snoek +390 eel +391 coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch +392 rock beauty, Holocanthus tricolor +393 anemone fish +394 sturgeon +395 gar, garfish, garpike, billfish, Lepisosteus osseus +396 lionfish +397 puffer, pufferfish, blowfish, globefish +398 abacus +399 abaya +400 academic gown, academic robe, judge's robe +401 accordion, piano accordion, squeeze box +402 acoustic guitar +403 aircraft carrier, carrier, flattop, attack aircraft carrier +404 airliner +405 airship, dirigible +406 altar +407 ambulance +408 amphibian, amphibious vehicle +409 analog clock +410 apiary, bee house +411 apron +412 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin +413 assault rifle, assault gun +414 backpack, back pack, knapsack, packsack, rucksack, haversack +415 bakery, bakeshop, bakehouse +416 balance beam, beam +417 balloon +418 ballpoint, ballpoint pen, ballpen, Biro +419 Band Aid +420 banjo +421 bannister, banister, balustrade, balusters, handrail +422 barbell +423 barber chair +424 barbershop +425 barn +426 barometer +427 barrel, cask +428 barrow, garden cart, lawn cart, wheelbarrow +429 baseball +430 basketball +431 bassinet +432 bassoon +433 bathing cap, swimming cap +434 bath towel +435 bathtub, bathing tub, bath, tub +436 beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon +437 beacon, lighthouse, beacon light, pharos +438 beaker +439 
bearskin, busby, shako +440 beer bottle +441 beer glass +442 bell cote, bell cot +443 bib +444 bicycle-built-for-two, tandem bicycle, tandem +445 bikini, two-piece +446 binder, ring-binder +447 binoculars, field glasses, opera glasses +448 birdhouse +449 boathouse +450 bobsled, bobsleigh, bob +451 bolo tie, bolo, bola tie, bola +452 bonnet, poke bonnet +453 bookcase +454 bookshop, bookstore, bookstall +455 bottlecap +456 bow +457 bow tie, bow-tie, bowtie +458 brass, memorial tablet, plaque +459 brassiere, bra, bandeau +460 breakwater, groin, groyne, mole, bulwark, seawall, jetty +461 breastplate, aegis, egis +462 broom +463 bucket, pail +464 buckle +465 bulletproof vest +466 bullet train, bullet +467 butcher shop, meat market +468 cab, hack, taxi, taxicab +469 caldron, cauldron +470 candle, taper, wax light +471 cannon +472 canoe +473 can opener, tin opener +474 cardigan +475 car mirror +476 carousel, carrousel, merry-go-round, roundabout, whirligig +477 carpenter's kit, tool kit +478 carton +479 car wheel +480 cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM +481 cassette +482 cassette player +483 castle +484 catamaran +485 CD player +486 cello, violoncello +487 cellular telephone, cellular phone, cellphone, cell, mobile phone +488 chain +489 chainlink fence +490 chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour +491 chain saw, chainsaw +492 chest +493 chiffonier, commode +494 chime, bell, gong +495 china cabinet, china closet +496 Christmas stocking +497 church, church building +498 cinema, movie theater, movie theatre, movie house, picture palace +499 cleaver, meat cleaver, chopper +500 cliff dwelling +501 cloak +502 clog, geta, patten, sabot +503 cocktail shaker +504 coffee mug +505 coffeepot +506 coil, spiral, volute, whorl, helix +507 combination lock +508 computer keyboard, keypad +509 confectionery, confectionary, candy store +510 container ship, containership, container vessel +511 convertible +512 corkscrew, bottle screw +513 cornet, horn, trumpet, trump +514 cowboy boot +515 cowboy hat, ten-gallon hat +516 cradle +517 crane +518 crash helmet +519 crate +520 crib, cot +521 Crock Pot +522 croquet ball +523 crutch +524 cuirass +525 dam, dike, dyke +526 desk +527 desktop computer +528 dial telephone, dial phone +529 diaper, nappy, napkin +530 digital clock +531 digital watch +532 dining table, board +533 dishrag, dishcloth +534 dishwasher, dish washer, dishwashing machine +535 disk brake, disc brake +536 dock, dockage, docking facility +537 dogsled, dog sled, dog sleigh +538 dome +539 doormat, welcome mat +540 drilling platform, offshore rig +541 drum, membranophone, tympan +542 drumstick +543 dumbbell +544 Dutch oven +545 electric fan, blower +546 electric guitar +547 electric locomotive +548 entertainment center +549 envelope +550 espresso maker +551 face powder +552 feather boa, boa +553 file, file cabinet, filing cabinet +554 fireboat +555 fire engine, fire truck +556 fire screen, fireguard +557 flagpole, flagstaff +558 flute, transverse flute +559 folding chair +560 football helmet +561 forklift +562 fountain +563 fountain pen +564 four-poster +565 freight car +566 French horn, horn +567 frying pan, frypan, skillet +568 fur coat +569 garbage truck, dustcart +570 gasmask, respirator, gas helmet +571 gas pump, gasoline pump, petrol pump, island dispenser +572 goblet +573 go-kart +574 golf ball +575 golfcart, golf cart +576 gondola +577 gong, tam-tam +578 gown +579 grand piano, 
grand +580 greenhouse, nursery, glasshouse +581 grille, radiator grille +582 grocery store, grocery, food market, market +583 guillotine +584 hair slide +585 hair spray +586 half track +587 hammer +588 hamper +589 hand blower, blow dryer, blow drier, hair dryer, hair drier +590 hand-held computer, hand-held microcomputer +591 handkerchief, hankie, hanky, hankey +592 hard disc, hard disk, fixed disk +593 harmonica, mouth organ, harp, mouth harp +594 harp +595 harvester, reaper +596 hatchet +597 holster +598 home theater, home theatre +599 honeycomb +600 hook, claw +601 hoopskirt, crinoline +602 horizontal bar, high bar +603 horse cart, horse-cart +604 hourglass +605 iPod +606 iron, smoothing iron +607 jack-o'-lantern +608 jean, blue jean, denim +609 jeep, landrover +610 jersey, T-shirt, tee shirt +611 jigsaw puzzle +612 jinrikisha, ricksha, rickshaw +613 joystick +614 kimono +615 knee pad +616 knot +617 lab coat, laboratory coat +618 ladle +619 lampshade, lamp shade +620 laptop, laptop computer +621 lawn mower, mower +622 lens cap, lens cover +623 letter opener, paper knife, paperknife +624 library +625 lifeboat +626 lighter, light, igniter, ignitor +627 limousine, limo +628 liner, ocean liner +629 lipstick, lip rouge +630 Loafer +631 lotion +632 loudspeaker, speaker, speaker unit, loudspeaker system, speaker system +633 loupe, jeweler's loupe +634 lumbermill, sawmill +635 magnetic compass +636 mailbag, postbag +637 mailbox, letter box +638 maillot +639 maillot, tank suit +640 manhole cover +641 maraca +642 marimba, xylophone +643 mask +644 matchstick +645 maypole +646 maze, labyrinth +647 measuring cup +648 medicine chest, medicine cabinet +649 megalith, megalithic structure +650 microphone, mike +651 microwave, microwave oven +652 military uniform +653 milk can +654 minibus +655 miniskirt, mini +656 minivan +657 missile +658 mitten +659 mixing bowl +660 mobile home, manufactured home +661 Model T +662 modem +663 monastery +664 monitor +665 moped +666 mortar +667 mortarboard +668 mosque +669 mosquito net +670 motor scooter, scooter +671 mountain bike, all-terrain bike, off-roader +672 mountain tent +673 mouse, computer mouse +674 mousetrap +675 moving van +676 muzzle +677 nail +678 neck brace +679 necklace +680 nipple +681 notebook, notebook computer +682 obelisk +683 oboe, hautboy, hautbois +684 ocarina, sweet potato +685 odometer, hodometer, mileometer, milometer +686 oil filter +687 organ, pipe organ +688 oscilloscope, scope, cathode-ray oscilloscope, CRO +689 overskirt +690 oxcart +691 oxygen mask +692 packet +693 paddle, boat paddle +694 paddlewheel, paddle wheel +695 padlock +696 paintbrush +697 pajama, pyjama, pj's, jammies +698 palace +699 panpipe, pandean pipe, syrinx +700 paper towel +701 parachute, chute +702 parallel bars, bars +703 park bench +704 parking meter +705 passenger car, coach, carriage +706 patio, terrace +707 pay-phone, pay-station +708 pedestal, plinth, footstall +709 pencil box, pencil case +710 pencil sharpener +711 perfume, essence +712 Petri dish +713 photocopier +714 pick, plectrum, plectron +715 pickelhaube +716 picket fence, paling +717 pickup, pickup truck +718 pier +719 piggy bank, penny bank +720 pill bottle +721 pillow +722 ping-pong ball +723 pinwheel +724 pirate, pirate ship +725 pitcher, ewer +726 plane, carpenter's plane, woodworking plane +727 planetarium +728 plastic bag +729 plate rack +730 plow, plough +731 plunger, plumber's helper +732 Polaroid camera, Polaroid Land camera +733 pole +734 police van, police wagon, paddy wagon, patrol wagon, 
wagon, black Maria +735 poncho +736 pool table, billiard table, snooker table +737 pop bottle, soda bottle +738 pot, flowerpot +739 potter's wheel +740 power drill +741 prayer rug, prayer mat +742 printer +743 prison, prison house +744 projectile, missile +745 projector +746 puck, hockey puck +747 punching bag, punch bag, punching ball, punchball +748 purse +749 quill, quill pen +750 quilt, comforter, comfort, puff +751 racer, race car, racing car +752 racket, racquet +753 radiator +754 radio, wireless +755 radio telescope, radio reflector +756 rain barrel +757 recreational vehicle, RV, R.V. +758 reel +759 reflex camera +760 refrigerator, icebox +761 remote control, remote +762 restaurant, eating house, eating place, eatery +763 revolver, six-gun, six-shooter +764 rifle +765 rocking chair, rocker +766 rotisserie +767 rubber eraser, rubber, pencil eraser +768 rugby ball +769 rule, ruler +770 running shoe +771 safe +772 safety pin +773 saltshaker, salt shaker +774 sandal +775 sarong +776 sax, saxophone +777 scabbard +778 scale, weighing machine +779 school bus +780 schooner +781 scoreboard +782 screen, CRT screen +783 screw +784 screwdriver +785 seat belt, seatbelt +786 sewing machine +787 shield, buckler +788 shoe shop, shoe-shop, shoe store +789 shoji +790 shopping basket +791 shopping cart +792 shovel +793 shower cap +794 shower curtain +795 ski +796 ski mask +797 sleeping bag +798 slide rule, slipstick +799 sliding door +800 slot, one-armed bandit +801 snorkel +802 snowmobile +803 snowplow, snowplough +804 soap dispenser +805 soccer ball +806 sock +807 solar dish, solar collector, solar furnace +808 sombrero +809 soup bowl +810 space bar +811 space heater +812 space shuttle +813 spatula +814 speedboat +815 spider web, spider's web +816 spindle +817 sports car, sport car +818 spotlight, spot +819 stage +820 steam locomotive +821 steel arch bridge +822 steel drum +823 stethoscope +824 stole +825 stone wall +826 stopwatch, stop watch +827 stove +828 strainer +829 streetcar, tram, tramcar, trolley, trolley car +830 stretcher +831 studio couch, day bed +832 stupa, tope +833 submarine, pigboat, sub, U-boat +834 suit, suit of clothes +835 sundial +836 sunglass +837 sunglasses, dark glasses, shades +838 sunscreen, sunblock, sun blocker +839 suspension bridge +840 swab, swob, mop +841 sweatshirt +842 swimming trunks, bathing trunks +843 swing +844 switch, electric switch, electrical switch +845 syringe +846 table lamp +847 tank, army tank, armored combat vehicle, armoured combat vehicle +848 tape player +849 teapot +850 teddy, teddy bear +851 television, television system +852 tennis ball +853 thatch, thatched roof +854 theater curtain, theatre curtain +855 thimble +856 thresher, thrasher, threshing machine +857 throne +858 tile roof +859 toaster +860 tobacco shop, tobacconist shop, tobacconist +861 toilet seat +862 torch +863 totem pole +864 tow truck, tow car, wrecker +865 toyshop +866 tractor +867 trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi +868 tray +869 trench coat +870 tricycle, trike, velocipede +871 trimaran +872 tripod +873 triumphal arch +874 trolleybus, trolley coach, trackless trolley +875 trombone +876 tub, vat +877 turnstile +878 typewriter keyboard +879 umbrella +880 unicycle, monocycle +881 upright, upright piano +882 vacuum, vacuum cleaner +883 vase +884 vault +885 velvet +886 vending machine +887 vestment +888 viaduct +889 violin, fiddle +890 volleyball +891 waffle iron +892 wall clock +893 wallet, billfold, notecase, pocketbook +894 wardrobe, 
closet, press +895 warplane, military plane +896 washbasin, handbasin, washbowl, lavabo, wash-hand basin +897 washer, automatic washer, washing machine +898 water bottle +899 water jug +900 water tower +901 whiskey jug +902 whistle +903 wig +904 window screen +905 window shade +906 Windsor tie +907 wine bottle +908 wing +909 wok +910 wooden spoon +911 wool, woolen, woollen +912 worm fence, snake fence, snake-rail fence, Virginia fence +913 wreck +914 yawl +915 yurt +916 web site, website, internet site, site +917 comic book +918 crossword puzzle, crossword +919 street sign +920 traffic light, traffic signal, stoplight +921 book jacket, dust cover, dust jacket, dust wrapper +922 menu +923 plate +924 guacamole +925 consomme +926 hot pot, hotpot +927 trifle +928 ice cream, icecream +929 ice lolly, lolly, lollipop, popsicle +930 French loaf +931 bagel, beigel +932 pretzel +933 cheeseburger +934 hotdog, hot dog, red hot +935 mashed potato +936 head cabbage +937 broccoli +938 cauliflower +939 zucchini, courgette +940 spaghetti squash +941 acorn squash +942 butternut squash +943 cucumber, cuke +944 artichoke, globe artichoke +945 bell pepper +946 cardoon +947 mushroom +948 Granny Smith +949 strawberry +950 orange +951 lemon +952 fig +953 pineapple, ananas +954 banana +955 jackfruit, jak, jack +956 custard apple +957 pomegranate +958 hay +959 carbonara +960 chocolate sauce, chocolate syrup +961 dough +962 meat loaf, meatloaf +963 pizza, pizza pie +964 potpie +965 burrito +966 red wine +967 espresso +968 cup +969 eggnog +970 alp +971 bubble +972 cliff, drop, drop-off +973 coral reef +974 geyser +975 lakeside, lakeshore +976 promontory, headland, head, foreland +977 sandbar, sand bar +978 seashore, coast, seacoast, sea-coast +979 valley, vale +980 volcano +981 ballplayer, baseball player +982 groom, bridegroom +983 scuba diver +984 rapeseed +985 daisy +986 yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum +987 corn +988 acorn +989 hip, rose hip, rosehip +990 buckeye, horse chestnut, conker +991 coral fungus +992 agaric +993 gyromitra +994 stinkhorn, carrion fungus +995 earthstar +996 hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa +997 bolete +998 ear, spike, capitulum +999 toilet tissue, toilet paper, bathroom tissue diff --git a/pytorch_grad_cam/__init__.py b/pytorch_grad_cam/__init__.py new file mode 100644 index 0000000..4d6e8f3 --- /dev/null +++ b/pytorch_grad_cam/__init__.py @@ -0,0 +1,20 @@ +from pytorch_grad_cam.grad_cam import GradCAM +from pytorch_grad_cam.hirescam import HiResCAM +from pytorch_grad_cam.grad_cam_elementwise import GradCAMElementWise +from pytorch_grad_cam.ablation_layer import AblationLayer, AblationLayerVit, AblationLayerFasterRCNN +from pytorch_grad_cam.ablation_cam import AblationCAM +from pytorch_grad_cam.xgrad_cam import XGradCAM +from pytorch_grad_cam.grad_cam_plusplus import GradCAMPlusPlus +from pytorch_grad_cam.score_cam import ScoreCAM +from pytorch_grad_cam.layer_cam import LayerCAM +from pytorch_grad_cam.eigen_cam import EigenCAM +from pytorch_grad_cam.eigen_grad_cam import EigenGradCAM +from pytorch_grad_cam.random_cam import RandomCAM +from pytorch_grad_cam.fullgrad_cam import FullGrad +from pytorch_grad_cam.guided_backprop import GuidedBackpropReLUModel +from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients +from pytorch_grad_cam.feature_factorization.deep_feature_factorization import DeepFeatureFactorization, run_dff_on_image +import 
pytorch_grad_cam.utils.model_targets +import pytorch_grad_cam.utils.reshape_transforms +import pytorch_grad_cam.metrics.cam_mult_image +import pytorch_grad_cam.metrics.road diff --git a/pytorch_grad_cam/__pycache__/__init__.cpython-37.pyc b/pytorch_grad_cam/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..f550765 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/__init__.cpython-37.pyc differ diff --git a/pytorch_grad_cam/__pycache__/__init__.cpython-38.pyc b/pytorch_grad_cam/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..556d829 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/__init__.cpython-38.pyc differ diff --git a/pytorch_grad_cam/__pycache__/ablation_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/ablation_cam.cpython-37.pyc new file mode 100644 index 0000000..a301e45 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/ablation_cam.cpython-37.pyc differ diff --git a/pytorch_grad_cam/__pycache__/ablation_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/ablation_cam.cpython-38.pyc new file mode 100644 index 0000000..caed0c2 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/ablation_cam.cpython-38.pyc differ diff --git a/pytorch_grad_cam/__pycache__/ablation_layer.cpython-37.pyc b/pytorch_grad_cam/__pycache__/ablation_layer.cpython-37.pyc new file mode 100644 index 0000000..04f5a5c Binary files /dev/null and b/pytorch_grad_cam/__pycache__/ablation_layer.cpython-37.pyc differ diff --git a/pytorch_grad_cam/__pycache__/ablation_layer.cpython-38.pyc b/pytorch_grad_cam/__pycache__/ablation_layer.cpython-38.pyc new file mode 100644 index 0000000..4ffd254 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/ablation_layer.cpython-38.pyc differ diff --git a/pytorch_grad_cam/__pycache__/activations_and_gradients.cpython-37.pyc b/pytorch_grad_cam/__pycache__/activations_and_gradients.cpython-37.pyc new file mode 100644 index 0000000..ea35268 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/activations_and_gradients.cpython-37.pyc differ diff --git a/pytorch_grad_cam/__pycache__/activations_and_gradients.cpython-38.pyc b/pytorch_grad_cam/__pycache__/activations_and_gradients.cpython-38.pyc new file mode 100644 index 0000000..d74710b Binary files /dev/null and b/pytorch_grad_cam/__pycache__/activations_and_gradients.cpython-38.pyc differ diff --git a/pytorch_grad_cam/__pycache__/base_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/base_cam.cpython-37.pyc new file mode 100644 index 0000000..b7e8d0c Binary files /dev/null and b/pytorch_grad_cam/__pycache__/base_cam.cpython-37.pyc differ diff --git a/pytorch_grad_cam/__pycache__/base_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/base_cam.cpython-38.pyc new file mode 100644 index 0000000..6c60dff Binary files /dev/null and b/pytorch_grad_cam/__pycache__/base_cam.cpython-38.pyc differ diff --git a/pytorch_grad_cam/__pycache__/eigen_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/eigen_cam.cpython-37.pyc new file mode 100644 index 0000000..113285b Binary files /dev/null and b/pytorch_grad_cam/__pycache__/eigen_cam.cpython-37.pyc differ diff --git a/pytorch_grad_cam/__pycache__/eigen_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/eigen_cam.cpython-38.pyc new file mode 100644 index 0000000..8a9ab14 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/eigen_cam.cpython-38.pyc differ diff --git a/pytorch_grad_cam/__pycache__/eigen_grad_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/eigen_grad_cam.cpython-37.pyc new file mode 100644 index 
0000000..f325b77 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/eigen_grad_cam.cpython-37.pyc differ diff --git a/pytorch_grad_cam/__pycache__/eigen_grad_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/eigen_grad_cam.cpython-38.pyc new file mode 100644 index 0000000..bae5a3d Binary files /dev/null and b/pytorch_grad_cam/__pycache__/eigen_grad_cam.cpython-38.pyc differ diff --git a/pytorch_grad_cam/__pycache__/fullgrad_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/fullgrad_cam.cpython-37.pyc new file mode 100644 index 0000000..5105c8a Binary files /dev/null and b/pytorch_grad_cam/__pycache__/fullgrad_cam.cpython-37.pyc differ diff --git a/pytorch_grad_cam/__pycache__/fullgrad_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/fullgrad_cam.cpython-38.pyc new file mode 100644 index 0000000..2cb48a8 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/fullgrad_cam.cpython-38.pyc differ diff --git a/pytorch_grad_cam/__pycache__/grad_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/grad_cam.cpython-37.pyc new file mode 100644 index 0000000..e64d058 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/grad_cam.cpython-37.pyc differ diff --git a/pytorch_grad_cam/__pycache__/grad_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/grad_cam.cpython-38.pyc new file mode 100644 index 0000000..b3c13fe Binary files /dev/null and b/pytorch_grad_cam/__pycache__/grad_cam.cpython-38.pyc differ diff --git a/pytorch_grad_cam/__pycache__/grad_cam_elementwise.cpython-37.pyc b/pytorch_grad_cam/__pycache__/grad_cam_elementwise.cpython-37.pyc new file mode 100644 index 0000000..2482d60 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/grad_cam_elementwise.cpython-37.pyc differ diff --git a/pytorch_grad_cam/__pycache__/grad_cam_elementwise.cpython-38.pyc b/pytorch_grad_cam/__pycache__/grad_cam_elementwise.cpython-38.pyc new file mode 100644 index 0000000..3daee8a Binary files /dev/null and b/pytorch_grad_cam/__pycache__/grad_cam_elementwise.cpython-38.pyc differ diff --git a/pytorch_grad_cam/__pycache__/grad_cam_plusplus.cpython-37.pyc b/pytorch_grad_cam/__pycache__/grad_cam_plusplus.cpython-37.pyc new file mode 100644 index 0000000..e8b7d1e Binary files /dev/null and b/pytorch_grad_cam/__pycache__/grad_cam_plusplus.cpython-37.pyc differ diff --git a/pytorch_grad_cam/__pycache__/grad_cam_plusplus.cpython-38.pyc b/pytorch_grad_cam/__pycache__/grad_cam_plusplus.cpython-38.pyc new file mode 100644 index 0000000..988e95f Binary files /dev/null and b/pytorch_grad_cam/__pycache__/grad_cam_plusplus.cpython-38.pyc differ diff --git a/pytorch_grad_cam/__pycache__/guided_backprop.cpython-37.pyc b/pytorch_grad_cam/__pycache__/guided_backprop.cpython-37.pyc new file mode 100644 index 0000000..03571a1 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/guided_backprop.cpython-37.pyc differ diff --git a/pytorch_grad_cam/__pycache__/guided_backprop.cpython-38.pyc b/pytorch_grad_cam/__pycache__/guided_backprop.cpython-38.pyc new file mode 100644 index 0000000..b9ca333 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/guided_backprop.cpython-38.pyc differ diff --git a/pytorch_grad_cam/__pycache__/hirescam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/hirescam.cpython-37.pyc new file mode 100644 index 0000000..4e95b72 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/hirescam.cpython-37.pyc differ diff --git a/pytorch_grad_cam/__pycache__/hirescam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/hirescam.cpython-38.pyc new file mode 100644 index 0000000..680e12e Binary files 
/dev/null and b/pytorch_grad_cam/__pycache__/hirescam.cpython-38.pyc differ diff --git a/pytorch_grad_cam/__pycache__/layer_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/layer_cam.cpython-37.pyc new file mode 100644 index 0000000..e47bddf Binary files /dev/null and b/pytorch_grad_cam/__pycache__/layer_cam.cpython-37.pyc differ diff --git a/pytorch_grad_cam/__pycache__/layer_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/layer_cam.cpython-38.pyc new file mode 100644 index 0000000..30022f9 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/layer_cam.cpython-38.pyc differ diff --git a/pytorch_grad_cam/__pycache__/random_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/random_cam.cpython-37.pyc new file mode 100644 index 0000000..70ed073 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/random_cam.cpython-37.pyc differ diff --git a/pytorch_grad_cam/__pycache__/random_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/random_cam.cpython-38.pyc new file mode 100644 index 0000000..dfcaa56 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/random_cam.cpython-38.pyc differ diff --git a/pytorch_grad_cam/__pycache__/score_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/score_cam.cpython-37.pyc new file mode 100644 index 0000000..0c6e7f2 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/score_cam.cpython-37.pyc differ diff --git a/pytorch_grad_cam/__pycache__/score_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/score_cam.cpython-38.pyc new file mode 100644 index 0000000..9b4f8a1 Binary files /dev/null and b/pytorch_grad_cam/__pycache__/score_cam.cpython-38.pyc differ diff --git a/pytorch_grad_cam/__pycache__/xgrad_cam.cpython-37.pyc b/pytorch_grad_cam/__pycache__/xgrad_cam.cpython-37.pyc new file mode 100644 index 0000000..1f3946d Binary files /dev/null and b/pytorch_grad_cam/__pycache__/xgrad_cam.cpython-37.pyc differ diff --git a/pytorch_grad_cam/__pycache__/xgrad_cam.cpython-38.pyc b/pytorch_grad_cam/__pycache__/xgrad_cam.cpython-38.pyc new file mode 100644 index 0000000..ed2585a Binary files /dev/null and b/pytorch_grad_cam/__pycache__/xgrad_cam.cpython-38.pyc differ diff --git a/pytorch_grad_cam/ablation_cam.py b/pytorch_grad_cam/ablation_cam.py new file mode 100644 index 0000000..77e65fc --- /dev/null +++ b/pytorch_grad_cam/ablation_cam.py @@ -0,0 +1,148 @@ +import numpy as np +import torch +import tqdm +from typing import Callable, List +from pytorch_grad_cam.base_cam import BaseCAM +from pytorch_grad_cam.utils.find_layers import replace_layer_recursive +from pytorch_grad_cam.ablation_layer import AblationLayer + + +""" Implementation of AblationCAM +https://openaccess.thecvf.com/content_WACV_2020/papers/Desai_Ablation-CAM_Visual_Explanations_for_Deep_Convolutional_Network_via_Gradient-free_Localization_WACV_2020_paper.pdf + +Ablate individual activations, and then measure the drop in the target score. + +In the current implementation, the target layer activations are cached, so they won't be re-computed. +However layers before it, if any, will not be cached. +This means that if the target layer is a large block, for example model.features (in VGG), there will +be a large saving in run time. + +Since we have to go over many channels and ablate them, and every channel ablation requires a forward pass, +it would be nice if we could avoid doing that for channels that won't contribute anyway, making it much faster. +The parameter ratio_channels_to_ablate controls how many channels should be ablated, using an experimental method +(to be improved).
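+(As a concrete example of the score-drop weighting computed below in get_cam_weights: if the original target score is 0.9 and ablating a channel lowers it to 0.6, that channel's weight is (0.9 - 0.6) / 0.9 ≈ 0.33.)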
The default 1.0 value means that all channels will be ablated. +""" + + +class AblationCAM(BaseCAM): + def __init__(self, + model: torch.nn.Module, + target_layers: List[torch.nn.Module], + use_cuda: bool = False, + reshape_transform: Callable = None, + ablation_layer: torch.nn.Module = AblationLayer(), + batch_size: int = 32, + ratio_channels_to_ablate: float = 1.0) -> None: + + super(AblationCAM, self).__init__(model, + target_layers, + use_cuda, + reshape_transform, + uses_gradients=False) + self.batch_size = batch_size + self.ablation_layer = ablation_layer + self.ratio_channels_to_ablate = ratio_channels_to_ablate + + def save_activation(self, module, input, output) -> None: + """ Helper function to save the raw activations from the target layer """ + self.activations = output + + def assemble_ablation_scores(self, + new_scores: list, + original_score: float, + ablated_channels: np.ndarray, + number_of_channels: int) -> np.ndarray: + """ Take the value from the channels that were ablated, + and just set the original score for the channels that were skipped """ + + index = 0 + result = [] + sorted_indices = np.argsort(ablated_channels) + ablated_channels = ablated_channels[sorted_indices] + new_scores = np.float32(new_scores)[sorted_indices] + + for i in range(number_of_channels): + if index < len(ablated_channels) and ablated_channels[index] == i: + weight = new_scores[index] + index = index + 1 + else: + weight = original_score + result.append(weight) + + return result + + def get_cam_weights(self, + input_tensor: torch.Tensor, + target_layer: torch.nn.Module, + targets: List[Callable], + activations: torch.Tensor, + grads: torch.Tensor) -> np.ndarray: + + # Do a forward pass, compute the target scores, and cache the + # activations + handle = target_layer.register_forward_hook(self.save_activation) + with torch.no_grad(): + outputs = self.model(input_tensor) + handle.remove() + original_scores = np.float32( + [target(output).cpu().item() for target, output in zip(targets, outputs)]) + + # Replace the layer with the ablation layer. + # When we finish, we will replace it back, so the original model is + # unchanged. + ablation_layer = self.ablation_layer + replace_layer_recursive(self.model, target_layer, ablation_layer) + + number_of_channels = activations.shape[1] + weights = [] + # This is a "gradient free" method, so we don't need gradients here. + with torch.no_grad(): + # Loop over each of the batch images and ablate activations for it. + for batch_index, (target, tensor) in enumerate( + zip(targets, input_tensor)): + new_scores = [] + batch_tensor = tensor.repeat(self.batch_size, 1, 1, 1) + + # Check which channels should be ablated. Normally this will be all channels, + # But we can also try to speed this up by using a low + # ratio_channels_to_ablate. + channels_to_ablate = ablation_layer.activations_to_be_ablated( + activations[batch_index, :], self.ratio_channels_to_ablate) + number_channels_to_ablate = len(channels_to_ablate) + + for i in tqdm.tqdm( + range( + 0, + number_channels_to_ablate, + self.batch_size)): + if i + self.batch_size > number_channels_to_ablate: + batch_tensor = batch_tensor[:( + number_channels_to_ablate - i)] + + # Change the state of the ablation layer so it ablates the next channels. + # TBD: Move this into the ablation layer forward pass. 
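+ # Note: set_next_batch repeats the cached activations once per batch member,
+ # and the ablation layer then zeroes a different channel (taken from its
+ # indices) for each member, so a single forward pass scores
+ # batch_tensor.size(0) ablated variants.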
+ ablation_layer.set_next_batch( + input_batch_index=batch_index, + activations=self.activations, + num_channels_to_ablate=batch_tensor.size(0)) + score = [target(o).cpu().item() + for o in self.model(batch_tensor)] + new_scores.extend(score) + ablation_layer.indices = ablation_layer.indices[batch_tensor.size( + 0):] + + new_scores = self.assemble_ablation_scores( + new_scores, + original_scores[batch_index], + channels_to_ablate, + number_of_channels) + weights.extend(new_scores) + + weights = np.float32(weights) + weights = weights.reshape(activations.shape[:2]) + original_scores = original_scores[:, None] + weights = (original_scores - weights) / original_scores + + # Replace the model back to the original state + replace_layer_recursive(self.model, ablation_layer, target_layer) + return weights diff --git a/pytorch_grad_cam/ablation_cam_multilayer.py b/pytorch_grad_cam/ablation_cam_multilayer.py new file mode 100644 index 0000000..9b9dc80 --- /dev/null +++ b/pytorch_grad_cam/ablation_cam_multilayer.py @@ -0,0 +1,136 @@ +import cv2 +import numpy as np +import torch +import tqdm +from pytorch_grad_cam.base_cam import BaseCAM + + +class AblationLayer(torch.nn.Module): + def __init__(self, layer, reshape_transform, indices): + super(AblationLayer, self).__init__() + + self.layer = layer + self.reshape_transform = reshape_transform + # The channels to zero out: + self.indices = indices + + def forward(self, x): + return self.__call__(x) + + def __call__(self, x): + output = self.layer(x) + + # Hack to work with ViT, + # Since the activation channels are last and not first like in CNNs + # Probably should remove it? + if self.reshape_transform is not None: + output = output.transpose(1, 2) + + for i in range(output.size(0)): + + # Commonly the minimum activation will be 0, + # And then it makes sense to zero it out. + # However depending on the architecture, + # If the values can be negative, we use very negative values + # to perform the ablation, deviating from the paper. + if torch.min(output) == 0: + output[i, self.indices[i], :] = 0 + else: + ABLATION_VALUE = 1e5 + output[i, self.indices[i], :] = torch.min( + output) - ABLATION_VALUE + + if self.reshape_transform is not None: + output = output.transpose(2, 1) + + return output + + +def replace_layer_recursive(model, old_layer, new_layer): + for name, layer in model._modules.items(): + if layer == old_layer: + model._modules[name] = new_layer + return True + elif replace_layer_recursive(layer, old_layer, new_layer): + return True + return False + + +class AblationCAM(BaseCAM): + def __init__(self, model, target_layers, use_cuda=False, + reshape_transform=None): + super(AblationCAM, self).__init__(model, target_layers, use_cuda, + reshape_transform) + + if len(target_layers) > 1: + print( + "Warning. You are using Ablation CAM with more than one layer.
" + "This is supported only if all layers have the same output shape") + + def set_ablation_layers(self): + self.ablation_layers = [] + for target_layer in self.target_layers: + ablation_layer = AblationLayer(target_layer, + self.reshape_transform, indices=[]) + self.ablation_layers.append(ablation_layer) + replace_layer_recursive(self.model, target_layer, ablation_layer) + + def unset_ablation_layers(self): + # replace the model back to the original state + for ablation_layer, target_layer in zip( + self.ablation_layers, self.target_layers): + replace_layer_recursive(self.model, ablation_layer, target_layer) + + def set_ablation_layer_batch_indices(self, indices): + for ablation_layer in self.ablation_layers: + ablation_layer.indices = indices + + def trim_ablation_layer_batch_indices(self, keep): + for ablation_layer in self.ablation_layers: + ablation_layer.indices = ablation_layer.indices[:keep] + + def get_cam_weights(self, + input_tensor, + target_category, + activations, + grads): + with torch.no_grad(): + outputs = self.model(input_tensor).cpu().numpy() + original_scores = [] + for i in range(input_tensor.size(0)): + original_scores.append(outputs[i, target_category[i]]) + original_scores = np.float32(original_scores) + + self.set_ablation_layers() + + if hasattr(self, "batch_size"): + BATCH_SIZE = self.batch_size + else: + BATCH_SIZE = 32 + + number_of_channels = activations.shape[1] + weights = [] + + with torch.no_grad(): + # Iterate over the input batch + for tensor, category in zip(input_tensor, target_category): + batch_tensor = tensor.repeat(BATCH_SIZE, 1, 1, 1) + for i in tqdm.tqdm(range(0, number_of_channels, BATCH_SIZE)): + self.set_ablation_layer_batch_indices( + list(range(i, i + BATCH_SIZE))) + + if i + BATCH_SIZE > number_of_channels: + keep = number_of_channels - i + batch_tensor = batch_tensor[:keep] + self.trim_ablation_layer_batch_indices(self, keep) + score = self.model(batch_tensor)[:, category].cpu().numpy() + weights.extend(score) + + weights = np.float32(weights) + weights = weights.reshape(activations.shape[:2]) + original_scores = original_scores[:, None] + weights = (original_scores - weights) / original_scores + + # replace the model back to the original state + self.unset_ablation_layers() + return weights diff --git a/pytorch_grad_cam/ablation_layer.py b/pytorch_grad_cam/ablation_layer.py new file mode 100644 index 0000000..b404f3b --- /dev/null +++ b/pytorch_grad_cam/ablation_layer.py @@ -0,0 +1,155 @@ +import torch +from collections import OrderedDict +import numpy as np +from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection + + +class AblationLayer(torch.nn.Module): + def __init__(self): + super(AblationLayer, self).__init__() + + def objectiveness_mask_from_svd(self, activations, threshold=0.01): + """ Experimental method to get a binary mask to compare if the activation is worth ablating. + The idea is to apply the EigenCAM method by doing PCA on the activations. + Then we create a binary mask by comparing to a low threshold. + Areas that are masked out, are probably not interesting anyway. + """ + + projection = get_2d_projection(activations[None, :])[0, :] + projection = np.abs(projection) + projection = projection - projection.min() + projection = projection / projection.max() + projection = projection > threshold + return projection + + def activations_to_be_ablated( + self, + activations, + ratio_channels_to_ablate=1.0): + """ Experimental method to get a binary mask to compare if the activation is worth ablating. 
+ Create a binary CAM mask with objectiveness_mask_from_svd. + Score each Activation channel, by seeing how much of its values are inside the mask. + Then keep the top channels. + + """ + if ratio_channels_to_ablate == 1.0: + self.indices = np.int32(range(activations.shape[0])) + return self.indices + + projection = self.objectiveness_mask_from_svd(activations) + + scores = [] + for channel in activations: + normalized = np.abs(channel) + normalized = normalized - normalized.min() + normalized = normalized / np.max(normalized) + score = (projection * normalized).sum() / normalized.sum() + scores.append(score) + scores = np.float32(scores) + + indices = list(np.argsort(scores)) + high_score_indices = indices[::- + 1][: int(len(indices) * + ratio_channels_to_ablate)] + low_score_indices = indices[: int( + len(indices) * ratio_channels_to_ablate)] + self.indices = np.int32(high_score_indices + low_score_indices) + return self.indices + + def set_next_batch( + self, + input_batch_index, + activations, + num_channels_to_ablate): + """ This creates the next batch of activations from the layer. + Just take corresponding batch member from activations, and repeat it num_channels_to_ablate times. + """ + self.activations = activations[input_batch_index, :, :, :].clone( + ).unsqueeze(0).repeat(num_channels_to_ablate, 1, 1, 1) + + def __call__(self, x): + output = self.activations + for i in range(output.size(0)): + # Commonly the minimum activation will be 0, + # And then it makes sense to zero it out. + # However depending on the architecture, + # If the values can be negative, we use very negative values + # to perform the ablation, deviating from the paper. + if torch.min(output) == 0: + output[i, self.indices[i], :] = 0 + else: + ABLATION_VALUE = 1e7 + output[i, self.indices[i], :] = torch.min( + output) - ABLATION_VALUE + + return output + + +class AblationLayerVit(AblationLayer): + def __init__(self): + super(AblationLayerVit, self).__init__() + + def __call__(self, x): + output = self.activations + output = output.transpose(1, len(output.shape) - 1) + for i in range(output.size(0)): + + # Commonly the minimum activation will be 0, + # And then it makes sense to zero it out. + # However depending on the architecture, + # If the values can be negative, we use very negative values + # to perform the ablation, deviating from the paper. + if torch.min(output) == 0: + output[i, self.indices[i], :] = 0 + else: + ABLATION_VALUE = 1e7 + output[i, self.indices[i], :] = torch.min( + output) - ABLATION_VALUE + + output = output.transpose(len(output.shape) - 1, 1) + + return output + + def set_next_batch( + self, + input_batch_index, + activations, + num_channels_to_ablate): + """ This creates the next batch of activations from the layer. + Just take corresponding batch member from activations, and repeat it num_channels_to_ablate times. + """ + repeat_params = [num_channels_to_ablate] + \ + len(activations.shape[:-1]) * [1] + self.activations = activations[input_batch_index, :, :].clone( + ).unsqueeze(0).repeat(*repeat_params) + + +class AblationLayerFasterRCNN(AblationLayer): + def __init__(self): + super(AblationLayerFasterRCNN, self).__init__() + + def set_next_batch( + self, + input_batch_index, + activations, + num_channels_to_ablate): + """ Extract the next batch member from activations, + and repeat it num_channels_to_ablate times. 
+ """ + self.activations = OrderedDict() + for key, value in activations.items(): + fpn_activation = value[input_batch_index, + :, :, :].clone().unsqueeze(0) + self.activations[key] = fpn_activation.repeat( + num_channels_to_ablate, 1, 1, 1) + + def __call__(self, x): + result = self.activations + layers = {0: '0', 1: '1', 2: '2', 3: '3', 4: 'pool'} + num_channels_to_ablate = result['pool'].size(0) + for i in range(num_channels_to_ablate): + pyramid_layer = int(self.indices[i] / 256) + index_in_pyramid_layer = int(self.indices[i] % 256) + result[layers[pyramid_layer]][i, + index_in_pyramid_layer, :, :] = -1000 + return result diff --git a/pytorch_grad_cam/activations_and_gradients.py b/pytorch_grad_cam/activations_and_gradients.py new file mode 100644 index 0000000..0c2071e --- /dev/null +++ b/pytorch_grad_cam/activations_and_gradients.py @@ -0,0 +1,46 @@ +class ActivationsAndGradients: + """ Class for extracting activations and + registering gradients from targetted intermediate layers """ + + def __init__(self, model, target_layers, reshape_transform): + self.model = model + self.gradients = [] + self.activations = [] + self.reshape_transform = reshape_transform + self.handles = [] + for target_layer in target_layers: + self.handles.append( + target_layer.register_forward_hook(self.save_activation)) + # Because of https://github.com/pytorch/pytorch/issues/61519, + # we don't use backward hook to record gradients. + self.handles.append( + target_layer.register_forward_hook(self.save_gradient)) + + def save_activation(self, module, input, output): + activation = output + + if self.reshape_transform is not None: + activation = self.reshape_transform(activation) + self.activations.append(activation.cpu().detach()) + + def save_gradient(self, module, input, output): + if not hasattr(output, "requires_grad") or not output.requires_grad: + # You can only register hooks on tensor requires grad. 
+ return + + # Gradients are computed in reverse order + def _store_grad(grad): + if self.reshape_transform is not None: + grad = self.reshape_transform(grad) + self.gradients = [grad.cpu().detach()] + self.gradients + + output.register_hook(_store_grad) + + def __call__(self, x): + self.gradients = [] + self.activations = [] + return self.model(x) + + def release(self): + for handle in self.handles: + handle.remove() diff --git a/pytorch_grad_cam/base_cam.py b/pytorch_grad_cam/base_cam.py new file mode 100644 index 0000000..7ee1929 --- /dev/null +++ b/pytorch_grad_cam/base_cam.py @@ -0,0 +1,203 @@ +import numpy as np +import torch +import ttach as tta +from typing import Callable, List, Tuple +from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients +from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection +from pytorch_grad_cam.utils.image import scale_cam_image +from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget + + +class BaseCAM: + def __init__(self, + model: torch.nn.Module, + target_layers: List[torch.nn.Module], + use_cuda: bool = False, + reshape_transform: Callable = None, + compute_input_gradient: bool = False, + uses_gradients: bool = True) -> None: + self.model = model.eval() + self.target_layers = target_layers + self.cuda = use_cuda + if self.cuda: + self.model = model.cuda() + self.reshape_transform = reshape_transform + self.compute_input_gradient = compute_input_gradient + self.uses_gradients = uses_gradients + self.activations_and_grads = ActivationsAndGradients( + self.model, target_layers, reshape_transform) + + """ Get a vector of weights for every channel in the target layer. + Methods that return weights channels, + will typically need to only implement this function. """ + + def get_cam_weights(self, + input_tensor: torch.Tensor, + target_layers: List[torch.nn.Module], + targets: List[torch.nn.Module], + activations: torch.Tensor, + grads: torch.Tensor) -> np.ndarray: + raise Exception("Not Implemented") + + def get_cam_image(self, + input_tensor: torch.Tensor, + target_layer: torch.nn.Module, + targets: List[torch.nn.Module], + activations: torch.Tensor, + grads: torch.Tensor, + eigen_smooth: bool = False) -> np.ndarray: + + weights = self.get_cam_weights(input_tensor, + target_layer, + targets, + activations, + grads) + weighted_activations = weights[:, :, None, None] * activations + if eigen_smooth: + cam = get_2d_projection(weighted_activations) + else: + cam = weighted_activations.sum(axis=1) + return cam + + def forward(self, + input_tensor: torch.Tensor, + targets: List[torch.nn.Module], + eigen_smooth: bool = False) -> np.ndarray: + + if self.cuda: + input_tensor = input_tensor.cuda() + + if self.compute_input_gradient: + input_tensor = torch.autograd.Variable(input_tensor, + requires_grad=True) + + outputs = self.activations_and_grads(input_tensor) + if targets is None: + target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1) + targets = [ClassifierOutputTarget( + category) for category in target_categories] + + if self.uses_gradients: + self.model.zero_grad() + loss = sum([target(output) + for target, output in zip(targets, outputs)]) + loss.backward(retain_graph=True) + + # In most of the saliency attribution papers, the saliency is + # computed with a single target layer. + # Commonly it is the last convolutional layer. + # Here we support passing a list with multiple target layers. 
+ # It will compute the saliency image for every image, + # and then aggregate them (with a default mean aggregation). + # This gives you more flexibility in case you just want to + # use all conv layers for example, all Batchnorm layers, + # or something else. + cam_per_layer = self.compute_cam_per_layer(input_tensor, + targets, + eigen_smooth) + return self.aggregate_multi_layers(cam_per_layer) + + def get_target_width_height(self, + input_tensor: torch.Tensor) -> Tuple[int, int]: + width, height = input_tensor.size(-1), input_tensor.size(-2) + return width, height + + def compute_cam_per_layer( + self, + input_tensor: torch.Tensor, + targets: List[torch.nn.Module], + eigen_smooth: bool) -> np.ndarray: + activations_list = [a.cpu().data.numpy() + for a in self.activations_and_grads.activations] + grads_list = [g.cpu().data.numpy() + for g in self.activations_and_grads.gradients] + target_size = self.get_target_width_height(input_tensor) + + cam_per_target_layer = [] + # Loop over the saliency image from every layer + for i in range(len(self.target_layers)): + target_layer = self.target_layers[i] + layer_activations = None + layer_grads = None + if i < len(activations_list): + layer_activations = activations_list[i] + if i < len(grads_list): + layer_grads = grads_list[i] + + cam = self.get_cam_image(input_tensor, + target_layer, + targets, + layer_activations, + layer_grads, + eigen_smooth) + cam = np.maximum(cam, 0) + scaled = scale_cam_image(cam, target_size) + cam_per_target_layer.append(scaled[:, None, :]) + + return cam_per_target_layer + + def aggregate_multi_layers( + self, + cam_per_target_layer: np.ndarray) -> np.ndarray: + cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1) + cam_per_target_layer = np.maximum(cam_per_target_layer, 0) + result = np.mean(cam_per_target_layer, axis=1) + return scale_cam_image(result) + + def forward_augmentation_smoothing(self, + input_tensor: torch.Tensor, + targets: List[torch.nn.Module], + eigen_smooth: bool = False) -> np.ndarray: + transforms = tta.Compose( + [ + tta.HorizontalFlip(), + tta.Multiply(factors=[0.9, 1, 1.1]), + ] + ) + cams = [] + for transform in transforms: + augmented_tensor = transform.augment_image(input_tensor) + cam = self.forward(augmented_tensor, + targets, + eigen_smooth) + + # The ttach library expects a tensor of size BxCxHxW + cam = cam[:, None, :, :] + cam = torch.from_numpy(cam) + cam = transform.deaugment_mask(cam) + + # Back to numpy float32, HxW + cam = cam.numpy() + cam = cam[:, 0, :, :] + cams.append(cam) + + cam = np.mean(np.float32(cams), axis=0) + return cam + + def __call__(self, + input_tensor: torch.Tensor, + targets: List[torch.nn.Module] = None, + aug_smooth: bool = False, + eigen_smooth: bool = False) -> np.ndarray: + + # Smooth the CAM result with test time augmentation + if aug_smooth is True: + return self.forward_augmentation_smoothing( + input_tensor, targets, eigen_smooth) + + return self.forward(input_tensor, + targets, eigen_smooth) + + def __del__(self): + self.activations_and_grads.release() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + self.activations_and_grads.release() + if isinstance(exc_value, IndexError): + # Handle IndexError here... + print( + f"An exception occurred in CAM with block: {exc_type}. 
Message: {exc_value}") + return True diff --git a/pytorch_grad_cam/eigen_cam.py b/pytorch_grad_cam/eigen_cam.py new file mode 100644 index 0000000..fd6d6bc --- /dev/null +++ b/pytorch_grad_cam/eigen_cam.py @@ -0,0 +1,23 @@ +from pytorch_grad_cam.base_cam import BaseCAM +from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection + +# https://arxiv.org/abs/2008.00299 + + +class EigenCAM(BaseCAM): + def __init__(self, model, target_layers, use_cuda=False, + reshape_transform=None): + super(EigenCAM, self).__init__(model, + target_layers, + use_cuda, + reshape_transform, + uses_gradients=False) + + def get_cam_image(self, + input_tensor, + target_layer, + target_category, + activations, + grads, + eigen_smooth): + return get_2d_projection(activations) diff --git a/pytorch_grad_cam/eigen_grad_cam.py b/pytorch_grad_cam/eigen_grad_cam.py new file mode 100644 index 0000000..3932a96 --- /dev/null +++ b/pytorch_grad_cam/eigen_grad_cam.py @@ -0,0 +1,21 @@ +from pytorch_grad_cam.base_cam import BaseCAM +from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection + +# Like Eigen CAM: https://arxiv.org/abs/2008.00299 +# But multiply the activations x gradients + + +class EigenGradCAM(BaseCAM): + def __init__(self, model, target_layers, use_cuda=False, + reshape_transform=None): + super(EigenGradCAM, self).__init__(model, target_layers, use_cuda, + reshape_transform) + + def get_cam_image(self, + input_tensor, + target_layer, + target_category, + activations, + grads, + eigen_smooth): + return get_2d_projection(grads * activations) diff --git a/pytorch_grad_cam/feature_factorization/__init__.py b/pytorch_grad_cam/feature_factorization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pytorch_grad_cam/feature_factorization/__pycache__/__init__.cpython-37.pyc b/pytorch_grad_cam/feature_factorization/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..6b20a42 Binary files /dev/null and b/pytorch_grad_cam/feature_factorization/__pycache__/__init__.cpython-37.pyc differ diff --git a/pytorch_grad_cam/feature_factorization/__pycache__/__init__.cpython-38.pyc b/pytorch_grad_cam/feature_factorization/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..6b3636f Binary files /dev/null and b/pytorch_grad_cam/feature_factorization/__pycache__/__init__.cpython-38.pyc differ diff --git a/pytorch_grad_cam/feature_factorization/__pycache__/deep_feature_factorization.cpython-37.pyc b/pytorch_grad_cam/feature_factorization/__pycache__/deep_feature_factorization.cpython-37.pyc new file mode 100644 index 0000000..0b22a48 Binary files /dev/null and b/pytorch_grad_cam/feature_factorization/__pycache__/deep_feature_factorization.cpython-37.pyc differ diff --git a/pytorch_grad_cam/feature_factorization/__pycache__/deep_feature_factorization.cpython-38.pyc b/pytorch_grad_cam/feature_factorization/__pycache__/deep_feature_factorization.cpython-38.pyc new file mode 100644 index 0000000..2f61f2f Binary files /dev/null and b/pytorch_grad_cam/feature_factorization/__pycache__/deep_feature_factorization.cpython-38.pyc differ diff --git a/pytorch_grad_cam/feature_factorization/deep_feature_factorization.py b/pytorch_grad_cam/feature_factorization/deep_feature_factorization.py new file mode 100644 index 0000000..b9db2c3 --- /dev/null +++ b/pytorch_grad_cam/feature_factorization/deep_feature_factorization.py @@ -0,0 +1,131 @@ +import numpy as np +from PIL import Image +import torch +from typing import Callable, List, Tuple, Optional +from 
sklearn.decomposition import NMF +from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients +from pytorch_grad_cam.utils.image import scale_cam_image, create_labels_legend, show_factorization_on_image + + +def dff(activations: np.ndarray, n_components: int = 5): + """ Compute Deep Feature Factorization on a 2d Activations tensor. + + :param activations: A numpy array of shape batch x channels x height x width + :param n_components: The number of components for the non negative matrix factorization + :returns: A tuple of the concepts (a numpy array with shape channels x components), + and the explanation heatmaps (a numpy array with shape batch x height x width) + """ + + batch_size, channels, h, w = activations.shape + reshaped_activations = activations.transpose((1, 0, 2, 3)) + reshaped_activations[np.isnan(reshaped_activations)] = 0 + reshaped_activations = reshaped_activations.reshape( + reshaped_activations.shape[0], -1) + offset = reshaped_activations.min(axis=-1) + reshaped_activations = reshaped_activations - offset[:, None] + + model = NMF(n_components=n_components, init='random', random_state=0) + W = model.fit_transform(reshaped_activations) + H = model.components_ + concepts = W + offset[:, None] + explanations = H.reshape(n_components, batch_size, h, w) + explanations = explanations.transpose((1, 0, 2, 3)) + return concepts, explanations + + +class DeepFeatureFactorization: + """ Deep Feature Factorization: https://arxiv.org/abs/1806.10206 + This gets a model and computes the 2D activations for a target layer, + and computes Non Negative Matrix Factorization on the activations. + + Optionally it runs a computation on the concept embeddings, + like running a classifier on them. + + The explanation heatmaps are scaled to the range [0, 1] + and to the input tensor width and height. + """ + + def __init__(self, + model: torch.nn.Module, + target_layer: torch.nn.Module, + reshape_transform: Callable = None, + computation_on_concepts=None + ): + self.model = model + self.computation_on_concepts = computation_on_concepts + self.activations_and_grads = ActivationsAndGradients( + self.model, [target_layer], reshape_transform) + + def __call__(self, + input_tensor: torch.Tensor, + n_components: int = 16): + batch_size, channels, h, w = input_tensor.size() + _ = self.activations_and_grads(input_tensor) + + with torch.no_grad(): + activations = self.activations_and_grads.activations[0].cpu( + ).numpy() + + concepts, explanations = dff(activations, n_components=n_components) + + processed_explanations = [] + + for batch in explanations: + processed_explanations.append(scale_cam_image(batch, (w, h))) + + if self.computation_on_concepts: + with torch.no_grad(): + concept_tensors = torch.from_numpy( + np.float32(concepts).transpose((1, 0))) + concept_outputs = self.computation_on_concepts( + concept_tensors).cpu().numpy() + return concepts, processed_explanations, concept_outputs + else: + return concepts, processed_explanations + + def __del__(self): + self.activations_and_grads.release() + + def __exit__(self, exc_type, exc_value, exc_tb): + self.activations_and_grads.release() + if isinstance(exc_value, IndexError): + # Handle IndexError here... + print( + f"An exception occurred in ActivationSummary with block: {exc_type}.
Message: {exc_value}") + return True + + +def run_dff_on_image(model: torch.nn.Module, + target_layer: torch.nn.Module, + classifier: torch.nn.Module, + img_pil: Image, + img_tensor: torch.Tensor, + reshape_transform=Optional[Callable], + n_components: int = 5, + top_k: int = 2) -> np.ndarray: + """ Helper function to create a Deep Feature Factorization visualization for a single image. + TBD: Run this on a batch with several images. + """ + rgb_img_float = np.array(img_pil) / 255 + dff = DeepFeatureFactorization(model=model, + reshape_transform=reshape_transform, + target_layer=target_layer, + computation_on_concepts=classifier) + + concepts, batch_explanations, concept_outputs = dff( + img_tensor[None, :], n_components) + + concept_outputs = torch.softmax( + torch.from_numpy(concept_outputs), + axis=-1).numpy() + concept_label_strings = create_labels_legend(concept_outputs, + labels=model.config.id2label, + top_k=top_k) + visualization = show_factorization_on_image( + rgb_img_float, + batch_explanations[0], + image_weight=0.3, + concept_labels=concept_label_strings) + + result = np.hstack((np.array(img_pil), visualization)) + return result diff --git a/pytorch_grad_cam/fullgrad_cam.py b/pytorch_grad_cam/fullgrad_cam.py new file mode 100644 index 0000000..1a2685e --- /dev/null +++ b/pytorch_grad_cam/fullgrad_cam.py @@ -0,0 +1,95 @@ +import numpy as np +import torch +from pytorch_grad_cam.base_cam import BaseCAM +from pytorch_grad_cam.utils.find_layers import find_layer_predicate_recursive +from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection +from pytorch_grad_cam.utils.image import scale_accross_batch_and_channels, scale_cam_image + +# https://arxiv.org/abs/1905.00780 + + +class FullGrad(BaseCAM): + def __init__(self, model, target_layers, use_cuda=False, + reshape_transform=None): + if len(target_layers) > 0: + print( + "Warning: target_layers is ignored in FullGrad. 
All bias layers will be used instead") + + def layer_with_2D_bias(layer): + bias_target_layers = [torch.nn.Conv2d, torch.nn.BatchNorm2d] + if type(layer) in bias_target_layers and layer.bias is not None: + return True + return False + target_layers = find_layer_predicate_recursive( + model, layer_with_2D_bias) + super( + FullGrad, + self).__init__( + model, + target_layers, + use_cuda, + reshape_transform, + compute_input_gradient=True) + self.bias_data = [self.get_bias_data( + layer).cpu().numpy() for layer in target_layers] + + def get_bias_data(self, layer): + # Borrowed from official paper impl: + # https://github.com/idiap/fullgrad-saliency/blob/master/saliency/tensor_extractor.py#L47 + if isinstance(layer, torch.nn.BatchNorm2d): + bias = - (layer.running_mean * layer.weight + / torch.sqrt(layer.running_var + layer.eps)) + layer.bias + return bias.data + else: + return layer.bias.data + + def compute_cam_per_layer( + self, + input_tensor, + target_category, + eigen_smooth): + input_grad = input_tensor.grad.data.cpu().numpy() + grads_list = [g.cpu().data.numpy() for g in + self.activations_and_grads.gradients] + cam_per_target_layer = [] + target_size = self.get_target_width_height(input_tensor) + + gradient_multiplied_input = input_grad * input_tensor.data.cpu().numpy() + gradient_multiplied_input = np.abs(gradient_multiplied_input) + gradient_multiplied_input = scale_accross_batch_and_channels( + gradient_multiplied_input, + target_size) + cam_per_target_layer.append(gradient_multiplied_input) + + # Loop over the saliency image from every layer + assert(len(self.bias_data) == len(grads_list)) + for bias, grads in zip(self.bias_data, grads_list): + bias = bias[None, :, None, None] + # In the paper they take the absolute value, + # but possibly taking only the positive gradients will work + # better.
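+ # For a per-channel bias, the gradient of the output w.r.t. the bias
+ # equals the gradient w.r.t. the layer output at each spatial location,
+ # so |bias * grads| approximates that bias term's FullGrad contribution.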
+ bias_grad = np.abs(bias * grads) + result = scale_accross_batch_and_channels( + bias_grad, target_size) + result = np.sum(result, axis=1) + cam_per_target_layer.append(result[:, None, :]) + cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1) + if eigen_smooth: + # Resize to a smaller image, since this method typically has a very large number of channels, + # and then consumes a lot of memory + cam_per_target_layer = scale_accross_batch_and_channels( + cam_per_target_layer, (target_size[0] // 8, target_size[1] // 8)) + cam_per_target_layer = get_2d_projection(cam_per_target_layer) + cam_per_target_layer = cam_per_target_layer[:, None, :, :] + cam_per_target_layer = scale_accross_batch_and_channels( + cam_per_target_layer, + target_size) + else: + cam_per_target_layer = np.sum( + cam_per_target_layer, axis=1)[:, None, :] + + return cam_per_target_layer + + def aggregate_multi_layers(self, cam_per_target_layer): + result = np.sum(cam_per_target_layer, axis=1) + return scale_cam_image(result) diff --git a/pytorch_grad_cam/grad_cam.py b/pytorch_grad_cam/grad_cam.py new file mode 100644 index 0000000..025bf45 --- /dev/null +++ b/pytorch_grad_cam/grad_cam.py @@ -0,0 +1,22 @@ +import numpy as np +from pytorch_grad_cam.base_cam import BaseCAM + + +class GradCAM(BaseCAM): + def __init__(self, model, target_layers, use_cuda=False, + reshape_transform=None): + super( + GradCAM, + self).__init__( + model, + target_layers, + use_cuda, + reshape_transform) + + def get_cam_weights(self, + input_tensor, + target_layer, + target_category, + activations, + grads): + return np.mean(grads, axis=(2, 3)) diff --git a/pytorch_grad_cam/grad_cam_elementwise.py b/pytorch_grad_cam/grad_cam_elementwise.py new file mode 100644 index 0000000..2698d47 --- /dev/null +++ b/pytorch_grad_cam/grad_cam_elementwise.py @@ -0,0 +1,30 @@ +import numpy as np +from pytorch_grad_cam.base_cam import BaseCAM +from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection + + +class GradCAMElementWise(BaseCAM): + def __init__(self, model, target_layers, use_cuda=False, + reshape_transform=None): + super( + GradCAMElementWise, + self).__init__( + model, + target_layers, + use_cuda, + reshape_transform) + + def get_cam_image(self, + input_tensor, + target_layer, + target_category, + activations, + grads, + eigen_smooth): + elementwise_activations = np.maximum(grads * activations, 0) + + if eigen_smooth: + cam = get_2d_projection(elementwise_activations) + else: + cam = elementwise_activations.sum(axis=1) + return cam diff --git a/pytorch_grad_cam/grad_cam_plusplus.py b/pytorch_grad_cam/grad_cam_plusplus.py new file mode 100644 index 0000000..4466826 --- /dev/null +++ b/pytorch_grad_cam/grad_cam_plusplus.py @@ -0,0 +1,32 @@ +import numpy as np +from pytorch_grad_cam.base_cam import BaseCAM + +# https://arxiv.org/abs/1710.11063 + + +class GradCAMPlusPlus(BaseCAM): + def __init__(self, model, target_layers, use_cuda=False, + reshape_transform=None): + super(GradCAMPlusPlus, self).__init__(model, target_layers, use_cuda, + reshape_transform) + + def get_cam_weights(self, + input_tensor, + target_layers, + target_category, + activations, + grads): + grads_power_2 = grads**2 + grads_power_3 = grads_power_2 * grads + # Equation 19 in https://arxiv.org/abs/1710.11063 + sum_activations = np.sum(activations, axis=(2, 3)) + eps = 0.000001 + aij = grads_power_2 / (2 * grads_power_2 + + sum_activations[:, :, None, None] * grads_power_3 + eps) + # Now bring back the ReLU from eq.7 in the paper, + # And zero out aijs where the 
activations are 0 + aij = np.where(grads != 0, aij, 0) + + weights = np.maximum(grads, 0) * aij + weights = np.sum(weights, axis=(2, 3)) + return weights diff --git a/pytorch_grad_cam/guided_backprop.py b/pytorch_grad_cam/guided_backprop.py new file mode 100644 index 0000000..602fbf3 --- /dev/null +++ b/pytorch_grad_cam/guided_backprop.py @@ -0,0 +1,100 @@ +import numpy as np +import torch +from torch.autograd import Function +from pytorch_grad_cam.utils.find_layers import replace_all_layer_type_recursive + + +class GuidedBackpropReLU(Function): + @staticmethod + def forward(self, input_img): + positive_mask = (input_img > 0).type_as(input_img) + output = torch.addcmul( + torch.zeros( + input_img.size()).type_as(input_img), + input_img, + positive_mask) + self.save_for_backward(input_img, output) + return output + + @staticmethod + def backward(self, grad_output): + input_img, output = self.saved_tensors + grad_input = None + + positive_mask_1 = (input_img > 0).type_as(grad_output) + positive_mask_2 = (grad_output > 0).type_as(grad_output) + grad_input = torch.addcmul( + torch.zeros( + input_img.size()).type_as(input_img), + torch.addcmul( + torch.zeros( + input_img.size()).type_as(input_img), + grad_output, + positive_mask_1), + positive_mask_2) + return grad_input + + +class GuidedBackpropReLUasModule(torch.nn.Module): + def __init__(self): + super(GuidedBackpropReLUasModule, self).__init__() + + def forward(self, input_img): + return GuidedBackpropReLU.apply(input_img) + + +class GuidedBackpropReLUModel: + def __init__(self, model, use_cuda): + self.model = model + self.model.eval() + self.cuda = use_cuda + if self.cuda: + self.model = self.model.cuda() + + def forward(self, input_img): + return self.model(input_img) + + def recursive_replace_relu_with_guidedrelu(self, module_top): + + for idx, module in module_top._modules.items(): + self.recursive_replace_relu_with_guidedrelu(module) + if module.__class__.__name__ == 'ReLU': + module_top._modules[idx] = GuidedBackpropReLU.apply + + def recursive_replace_guidedrelu_with_relu(self, module_top): + try: + for idx, module in module_top._modules.items(): + self.recursive_replace_guidedrelu_with_relu(module) + if module == GuidedBackpropReLU.apply: + module_top._modules[idx] = torch.nn.ReLU() + except BaseException: + pass + + def __call__(self, input_img, target_category=None): + replace_all_layer_type_recursive(self.model, + torch.nn.ReLU, + GuidedBackpropReLUasModule()) + + if self.cuda: + input_img = input_img.cuda() + + input_img = input_img.requires_grad_(True) + + output = self.forward(input_img) + + if target_category is None: + target_category = np.argmax(output.cpu().data.numpy()) + + loss = output[0, target_category] + loss.backward(retain_graph=True) + + output = input_img.grad.cpu().data.numpy() + output = output[0, :, :, :] + output = output.transpose((1, 2, 0)) + + replace_all_layer_type_recursive(self.model, + GuidedBackpropReLUasModule, + torch.nn.ReLU()) + + return output diff --git a/pytorch_grad_cam/hirescam.py b/pytorch_grad_cam/hirescam.py new file mode 100644 index 0000000..381d8d4 --- /dev/null +++ b/pytorch_grad_cam/hirescam.py @@ -0,0 +1,32 @@ +import numpy as np +from pytorch_grad_cam.base_cam import BaseCAM +from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection + + +class HiResCAM(BaseCAM): + def __init__(self, model, target_layers, use_cuda=False, + reshape_transform=None): + super( + HiResCAM, + self).__init__( + model, + target_layers, + use_cuda, + reshape_transform) + + def
diff --git a/pytorch_grad_cam/hirescam.py b/pytorch_grad_cam/hirescam.py
new file mode 100644
index 0000000..381d8d4
--- /dev/null
+++ b/pytorch_grad_cam/hirescam.py
@@ -0,0 +1,32 @@
+import numpy as np
+from pytorch_grad_cam.base_cam import BaseCAM
+from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
+
+
+class HiResCAM(BaseCAM):
+    def __init__(self, model, target_layers, use_cuda=False,
+                 reshape_transform=None):
+        super(HiResCAM, self).__init__(model, target_layers, use_cuda,
+                                       reshape_transform)
+
+    def get_cam_image(self,
+                      input_tensor,
+                      target_layer,
+                      target_category,
+                      activations,
+                      grads,
+                      eigen_smooth):
+        elementwise_activations = grads * activations
+
+        if eigen_smooth:
+            print(
+                "Warning: HiResCAM's faithfulness guarantees do not hold if smoothing is applied")
+            cam = get_2d_projection(elementwise_activations)
+        else:
+            cam = elementwise_activations.sum(axis=1)
+        return cam
diff --git a/pytorch_grad_cam/layer_cam.py b/pytorch_grad_cam/layer_cam.py
new file mode 100644
index 0000000..971443d
--- /dev/null
+++ b/pytorch_grad_cam/layer_cam.py
@@ -0,0 +1,36 @@
+import numpy as np
+from pytorch_grad_cam.base_cam import BaseCAM
+from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
+
+# https://ieeexplore.ieee.org/document/9462463
+
+
+class LayerCAM(BaseCAM):
+    def __init__(self, model, target_layers, use_cuda=False,
+                 reshape_transform=None):
+        super(LayerCAM, self).__init__(model, target_layers, use_cuda,
+                                       reshape_transform)
+
+    def get_cam_image(self,
+                      input_tensor,
+                      target_layer,
+                      target_category,
+                      activations,
+                      grads,
+                      eigen_smooth):
+        spatial_weighted_activations = np.maximum(grads, 0) * activations
+
+        if eigen_smooth:
+            cam = get_2d_projection(spatial_weighted_activations)
+        else:
+            cam = spatial_weighted_activations.sum(axis=1)
+        return cam
diff --git a/pytorch_grad_cam/metrics/__init__.py b/pytorch_grad_cam/metrics/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pytorch_grad_cam/metrics/__pycache__/__init__.cpython-37.pyc b/pytorch_grad_cam/metrics/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..9278b11
Binary files /dev/null and b/pytorch_grad_cam/metrics/__pycache__/__init__.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/metrics/__pycache__/__init__.cpython-38.pyc b/pytorch_grad_cam/metrics/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..6ec40fa
Binary files /dev/null and b/pytorch_grad_cam/metrics/__pycache__/__init__.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/metrics/__pycache__/cam_mult_image.cpython-37.pyc b/pytorch_grad_cam/metrics/__pycache__/cam_mult_image.cpython-37.pyc
new file mode 100644
index 0000000..4239082
Binary files /dev/null and b/pytorch_grad_cam/metrics/__pycache__/cam_mult_image.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/metrics/__pycache__/cam_mult_image.cpython-38.pyc b/pytorch_grad_cam/metrics/__pycache__/cam_mult_image.cpython-38.pyc
new file mode 100644
index 0000000..cbf2ee5
Binary files /dev/null and b/pytorch_grad_cam/metrics/__pycache__/cam_mult_image.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/metrics/__pycache__/perturbation_confidence.cpython-37.pyc b/pytorch_grad_cam/metrics/__pycache__/perturbation_confidence.cpython-37.pyc
new file mode 100644
index 0000000..b1b9b2e
Binary files /dev/null and b/pytorch_grad_cam/metrics/__pycache__/perturbation_confidence.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/metrics/__pycache__/perturbation_confidence.cpython-38.pyc b/pytorch_grad_cam/metrics/__pycache__/perturbation_confidence.cpython-38.pyc
new file mode 100644
index 0000000..e788488
Binary files /dev/null and b/pytorch_grad_cam/metrics/__pycache__/perturbation_confidence.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/metrics/__pycache__/road.cpython-37.pyc b/pytorch_grad_cam/metrics/__pycache__/road.cpython-37.pyc
new file mode 100644
index 0000000..83b5e13
Binary files /dev/null and b/pytorch_grad_cam/metrics/__pycache__/road.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/metrics/__pycache__/road.cpython-38.pyc b/pytorch_grad_cam/metrics/__pycache__/road.cpython-38.pyc
new file mode 100644
index 0000000..309afc7
Binary files /dev/null and b/pytorch_grad_cam/metrics/__pycache__/road.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/metrics/cam_mult_image.py b/pytorch_grad_cam/metrics/cam_mult_image.py
new file mode 100644
index 0000000..bd4bf8a
--- /dev/null
+++ b/pytorch_grad_cam/metrics/cam_mult_image.py
@@ -0,0 +1,37 @@
+import torch
+import numpy as np
+from typing import List, Callable
+from pytorch_grad_cam.metrics.perturbation_confidence import PerturbationConfidenceMetric
+
+
+def multiply_tensor_with_cam(input_tensor: torch.Tensor,
+                             cam: torch.Tensor):
+    """ Multiply an input tensor (after normalization)
+        with a pixel attribution map
+    """
+    return input_tensor * cam
+
+
+class CamMultImageConfidenceChange(PerturbationConfidenceMetric):
+    def __init__(self):
+        super(CamMultImageConfidenceChange,
+              self).__init__(multiply_tensor_with_cam)
+
+
+class DropInConfidence(CamMultImageConfidenceChange):
+    def __init__(self):
+        super(DropInConfidence, self).__init__()
+
+    def __call__(self, *args, **kwargs):
+        scores = super(DropInConfidence, self).__call__(*args, **kwargs)
+        scores = -scores
+        return np.maximum(scores, 0)
+
+
+class IncreaseInConfidence(CamMultImageConfidenceChange):
+    def __init__(self):
+        super(IncreaseInConfidence, self).__init__()
+
+    def __call__(self, *args, **kwargs):
+        scores = super(IncreaseInConfidence, self).__call__(*args, **kwargs)
+        return np.float32(scores > 0)
diff --git a/pytorch_grad_cam/metrics/perturbation_confidence.py b/pytorch_grad_cam/metrics/perturbation_confidence.py
new file mode 100644
index 0000000..813ffc7
--- /dev/null
+++ b/pytorch_grad_cam/metrics/perturbation_confidence.py
@@ -0,0 +1,109 @@
+import torch
+import numpy as np
+from typing import List, Callable
+
+import cv2
+
+
+class PerturbationConfidenceMetric:
+    def __init__(self, perturbation):
+        self.perturbation = perturbation
+
+    def __call__(self, input_tensor: torch.Tensor,
+                 cams: np.ndarray,
+                 targets: List[Callable],
+                 model: torch.nn.Module,
+                 return_visualization=False,
+                 return_diff=True):
+
+        if return_diff:
+            with torch.no_grad():
+                outputs = model(input_tensor)
+                scores = [target(output).cpu().numpy()
+                          for target, output in zip(targets, outputs)]
+                scores = np.float32(scores)
+
+        batch_size = input_tensor.size(0)
+        perturbated_tensors = []
+        for i in range(batch_size):
+            cam = cams[i]
+            tensor = self.perturbation(input_tensor[i, ...].cpu(),
+                                       torch.from_numpy(cam))
+            tensor = tensor.to(input_tensor.device)
+            perturbated_tensors.append(tensor.unsqueeze(0))
+        perturbated_tensors = torch.cat(perturbated_tensors)
+
+        with torch.no_grad():
+            outputs_after_imputation = model(perturbated_tensors)
+        scores_after_imputation = [
+            target(output).cpu().numpy() for target, output in zip(
+                targets, outputs_after_imputation)]
+        scores_after_imputation = np.float32(scores_after_imputation)
+
+        if return_diff:
+            result = scores_after_imputation - scores
+        else:
+            result = scores_after_imputation
+
+        if return_visualization:
+            return result, perturbated_tensors
+        else:
+            return result
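PerturbationConfidenceMetric runs the model before and after a CAM-guided perturbation and reports the change in the target score, which is what DropInConfidence from cam_mult_image.py builds on. A minimal sketch with tiny stand-ins (the toy classifier, input, and fake CAM below are illustrative only):

    import torch
    import numpy as np
    from pytorch_grad_cam.metrics.cam_mult_image import DropInConfidence
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputSoftmaxTarget

    # Stand-ins: a tiny classifier, one input image, and a random CAM in [0, 1].
    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
    input_tensor = torch.randn(1, 3, 32, 32)
    grayscale_cam = np.random.rand(1, 32, 32).astype(np.float32)

    targets = [ClassifierOutputSoftmaxTarget(3)]
    # Multiplies the image by the CAM and reports how much confidence dropped.
    drop = DropInConfidence()(input_tensor, grayscale_cam, targets, model)
    print(drop)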
+
+
+class RemoveMostRelevantFirst:
+    def __init__(self, percentile, imputer):
+        self.percentile = percentile
+        self.imputer = imputer
+
+    def __call__(self, input_tensor, mask):
+        imputer = self.imputer
+        if self.percentile != 'auto':
+            threshold = np.percentile(mask.cpu().numpy(), self.percentile)
+            binary_mask = np.float32(mask < threshold)
+        else:
+            _, binary_mask = cv2.threshold(
+                np.uint8(mask * 255), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+
+        binary_mask = torch.from_numpy(binary_mask)
+        binary_mask = binary_mask.to(mask.device)
+        return imputer(input_tensor, binary_mask)
+
+
+class RemoveLeastRelevantFirst(RemoveMostRelevantFirst):
+    def __init__(self, percentile, imputer):
+        super(RemoveLeastRelevantFirst, self).__init__(percentile, imputer)
+
+    def __call__(self, input_tensor, mask):
+        return super(RemoveLeastRelevantFirst, self).__call__(
+            input_tensor, 1 - mask)
+
+
+class AveragerAcrossThresholds:
+    def __init__(
+            self,
+            imputer,
+            percentiles=[10, 20, 30, 40, 50, 60, 70, 80, 90]):
+        self.imputer = imputer
+        self.percentiles = percentiles
+
+    def __call__(self,
+                 input_tensor: torch.Tensor,
+                 cams: np.ndarray,
+                 targets: List[Callable],
+                 model: torch.nn.Module):
+        scores = []
+        for percentile in self.percentiles:
+            imputer = self.imputer(percentile)
+            scores.append(imputer(input_tensor, cams, targets, model))
+        return np.mean(np.float32(scores), axis=0)
diff --git a/pytorch_grad_cam/metrics/road.py b/pytorch_grad_cam/metrics/road.py
new file mode 100644
index 0000000..7b09c4b
--- /dev/null
+++ b/pytorch_grad_cam/metrics/road.py
@@ -0,0 +1,181 @@
+# A Consistent and Efficient Evaluation Strategy for Attribution Methods
+# https://arxiv.org/abs/2202.00449
+# Taken from https://raw.githubusercontent.com/tleemann/road_evaluation/main/imputations.py
+# MIT License
+
+# Copyright (c) 2022 Tobias Leemann
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+# Implementations of our imputation models.
+import torch
+import numpy as np
+from scipy.sparse import lil_matrix, csc_matrix
+from scipy.sparse.linalg import spsolve
+from typing import List, Callable
+from pytorch_grad_cam.metrics.perturbation_confidence import PerturbationConfidenceMetric, \
+    AveragerAcrossThresholds, \
+    RemoveMostRelevantFirst, \
+    RemoveLeastRelevantFirst
+
+# The weights of the surrounding pixels
+neighbors_weights = [((1, 1), 1 / 12),
+                     ((0, 1), 1 / 6),
+                     ((-1, 1), 1 / 12),
+                     ((1, -1), 1 / 12),
+                     ((0, -1), 1 / 6),
+                     ((-1, -1), 1 / 12),
+                     ((1, 0), 1 / 6),
+                     ((-1, 0), 1 / 6)]
+
+
+class NoisyLinearImputer:
+    def __init__(self,
+                 noise: float = 0.01,
+                 weighting: List[float] = neighbors_weights):
+        """
+        Noisy linear imputation.
+        noise: magnitude of noise to add (absolute, set to 0 for no noise)
+        weighting: Weights of the neighboring pixels in the computation.
+            List of tuples of (offset, weight)
+        """
+        self.noise = noise
+        # Keep the weighting that was passed in (the original always reset
+        # this to the module-level default).
+        self.weighting = weighting
+
+    @staticmethod
+    def add_offset_to_indices(indices, offset, mask_shape):
+        """ Add the corresponding offset to the indices.
+            Return new indices plus a valid bit-vector. """
+        cord1 = indices % mask_shape[1]
+        cord0 = indices // mask_shape[1]
+        cord0 += offset[0]
+        cord1 += offset[1]
+        valid = ((cord0 < 0) | (cord1 < 0) |
+                 (cord0 >= mask_shape[0]) |
+                 (cord1 >= mask_shape[1]))
+        return ~valid, indices + offset[0] * mask_shape[1] + offset[1]
+
+    @staticmethod
+    def setup_sparse_system(mask, img, neighbors_weights):
+        """ Vectorized version to set up the equation system.
+            mask: (H, W)-tensor of missing pixels.
+            Image: (H, W, C)-tensor of all values.
+            Return (N,N)-System matrix, (N,C)-Right hand side for each of the C channels.
+        """
+        maskflt = mask.flatten()
+        imgflat = img.reshape((img.shape[0], -1))
+        # Indices that are imputed in the flattened mask:
+        indices = np.argwhere(maskflt == 0).flatten()
+        coords_to_vidx = np.zeros(len(maskflt), dtype=int)
+        coords_to_vidx[indices] = np.arange(len(indices))
+        numEquations = len(indices)
+        # System matrix:
+        A = lil_matrix((numEquations, numEquations))
+        b = np.zeros((numEquations, img.shape[0]))
+        # Sum of weights assigned:
+        sum_neighbors = np.ones(numEquations)
+        for n in neighbors_weights:
+            offset, weight = n[0], n[1]
+            # Take out outliers
+            valid, new_coords = NoisyLinearImputer.add_offset_to_indices(
+                indices, offset, mask.shape)
+            valid_coords = new_coords[valid]
+            valid_ids = np.argwhere(valid == 1).flatten()
+            # Add values to the right hand-side
+            has_values_coords = valid_coords[maskflt[valid_coords] > 0.5]
+            has_values_ids = valid_ids[maskflt[valid_coords] > 0.5]
+            b[has_values_ids, :] -= weight * imgflat[:, has_values_coords].T
+            # Add weights to the system (left hand side):
+            # find the variable coordinates in the system.
+            has_no_values = valid_coords[maskflt[valid_coords] < 0.5]
+            variable_ids = coords_to_vidx[has_no_values]
+            has_no_values_ids = valid_ids[maskflt[valid_coords] < 0.5]
+            A[has_no_values_ids, variable_ids] = weight
+            # Reduce weight for invalid
+            sum_neighbors[np.argwhere(valid == 0).flatten()] = \
+                sum_neighbors[np.argwhere(valid == 0).flatten()] - weight
+
+        A[np.arange(numEquations), np.arange(numEquations)] = -sum_neighbors
+        return A, b
+
+    def __call__(self, img: torch.Tensor, mask: torch.Tensor):
+        """ Our linear imputation scheme: do the linear infilling.
+            img: original image, (C,H,W)-tensor
+            mask: (H,W)-tensor
+        """
+        imgflt = img.reshape(img.shape[0], -1)
+        maskflt = mask.reshape(-1)
+        # Indices that need to be imputed.
+        indices_linear = np.argwhere(maskflt == 0).flatten()
+        # Set up sparse equation system, solve system.
+        A, b = NoisyLinearImputer.setup_sparse_system(
+            mask.numpy(), img.numpy(), self.weighting)
+        res = torch.tensor(spsolve(csc_matrix(A), b), dtype=torch.float)
+
+        # Fill the values with the solution of the system.
+        img_infill = imgflt.clone()
+        img_infill[:, indices_linear] = res.t() + self.noise * \
+            torch.randn_like(res.t())
+
+        return img_infill.reshape_as(img)
+
+
+class ROADMostRelevantFirst(PerturbationConfidenceMetric):
+    def __init__(self, percentile=80):
+        super(ROADMostRelevantFirst, self).__init__(
+            RemoveMostRelevantFirst(percentile, NoisyLinearImputer()))
+
+
+class ROADLeastRelevantFirst(PerturbationConfidenceMetric):
+    def __init__(self, percentile=20):
+        super(ROADLeastRelevantFirst, self).__init__(
+            RemoveLeastRelevantFirst(percentile, NoisyLinearImputer()))
+
+
+class ROADMostRelevantFirstAverage(AveragerAcrossThresholds):
+    def __init__(self, percentiles=[10, 20, 30, 40, 50, 60, 70, 80, 90]):
+        super(ROADMostRelevantFirstAverage, self).__init__(
+            ROADMostRelevantFirst, percentiles)
+
+
+class ROADLeastRelevantFirstAverage(AveragerAcrossThresholds):
+    def __init__(self, percentiles=[10, 20, 30, 40, 50, 60, 70, 80, 90]):
+        super(ROADLeastRelevantFirstAverage, self).__init__(
+            ROADLeastRelevantFirst, percentiles)
+
+
+class ROADCombined:
+    def __init__(self, percentiles=[10, 20, 30, 40, 50, 60, 70, 80, 90]):
+        self.percentiles = percentiles
+        self.morf_averager = ROADMostRelevantFirstAverage(percentiles)
+        self.lerf_averager = ROADLeastRelevantFirstAverage(percentiles)
+
+    def __call__(self,
+                 input_tensor: torch.Tensor,
+                 cams: np.ndarray,
+                 targets: List[Callable],
+                 model: torch.nn.Module):
+
+        scores_lerf = self.lerf_averager(input_tensor, cams, targets, model)
+        scores_morf = self.morf_averager(input_tensor, cams, targets, model)
+        return (scores_lerf - scores_morf) / 2
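ROADCombined averages the MoRF and LeRF perturbation curves into one number: removing the most relevant pixels first should hurt confidence more than removing the least relevant ones. A minimal sketch with the same toy stand-ins as before (model, input, and CAM are illustrative, not from a real run):

    import torch
    import numpy as np
    from pytorch_grad_cam.metrics.road import ROADCombined
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputSoftmaxTarget

    # Stand-ins: a tiny classifier, one image, and a random CAM in [0, 1].
    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
    input_tensor = torch.randn(1, 3, 32, 32)
    grayscale_cam = np.random.rand(1, 32, 32).astype(np.float32)

    metric = ROADCombined(percentiles=[20, 40, 60, 80])
    targets = [ClassifierOutputSoftmaxTarget(3)]
    # Higher is better: (LeRF - MoRF) / 2, averaged over the percentiles.
    score = metric(input_tensor, grayscale_cam, targets, model)
    print(score)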
diff --git a/pytorch_grad_cam/random_cam.py b/pytorch_grad_cam/random_cam.py
new file mode 100644
index 0000000..5bb6ecc
--- /dev/null
+++ b/pytorch_grad_cam/random_cam.py
@@ -0,0 +1,22 @@
+import numpy as np
+from pytorch_grad_cam.base_cam import BaseCAM
+
+
+class RandomCAM(BaseCAM):
+    def __init__(self, model, target_layers, use_cuda=False,
+                 reshape_transform=None):
+        super(RandomCAM, self).__init__(model, target_layers, use_cuda,
+                                        reshape_transform)
+
+    def get_cam_weights(self,
+                        input_tensor,
+                        target_layer,
+                        target_category,
+                        activations,
+                        grads):
+        return np.random.uniform(-1, 1, size=(grads.shape[0], grads.shape[1]))
diff --git a/pytorch_grad_cam/score_cam.py b/pytorch_grad_cam/score_cam.py
new file mode 100644
index 0000000..38460c5
--- /dev/null
+++ b/pytorch_grad_cam/score_cam.py
@@ -0,0 +1,60 @@
+import torch
+import tqdm
+from pytorch_grad_cam.base_cam import BaseCAM
+
+
+class ScoreCAM(BaseCAM):
+    def __init__(
+            self,
+            model,
+            target_layers,
+            use_cuda=False,
+            reshape_transform=None):
+        super(ScoreCAM, self).__init__(model,
+                                       target_layers,
+                                       use_cuda,
+                                       reshape_transform=reshape_transform,
+                                       uses_gradients=False)
+
+    def get_cam_weights(self,
+                        input_tensor,
+                        target_layer,
+                        targets,
+                        activations,
+                        grads):
+        with torch.no_grad():
+            upsample = torch.nn.UpsamplingBilinear2d(
+                size=input_tensor.shape[-2:])
+            activation_tensor = torch.from_numpy(activations)
+            if self.cuda:
+                activation_tensor = activation_tensor.cuda()
+
+            upsampled = upsample(activation_tensor)
+
+            maxs = upsampled.view(upsampled.size(0),
+                                  upsampled.size(1), -1).max(dim=-1)[0]
+            mins = upsampled.view(upsampled.size(0),
+                                  upsampled.size(1), -1).min(dim=-1)[0]
+
+            maxs, mins = maxs[:, :, None, None], mins[:, :, None, None]
+            upsampled = (upsampled - mins) / (maxs - mins)
+
+            input_tensors = input_tensor[:, None,
+                                         :, :] * upsampled[:, :, None, :, :]
+
+            if hasattr(self, "batch_size"):
+                BATCH_SIZE = self.batch_size
+            else:
+                BATCH_SIZE = 16
+
+            scores = []
+            for target, tensor in zip(targets, input_tensors):
+                for i in tqdm.tqdm(range(0, tensor.size(0), BATCH_SIZE)):
+                    batch = tensor[i: i + BATCH_SIZE, :]
+                    outputs = [target(o).cpu().item()
+                               for o in self.model(batch)]
+                    scores.extend(outputs)
+            scores = torch.Tensor(scores)
+            scores = scores.view(activations.shape[0], activations.shape[1])
+            weights = torch.nn.Softmax(dim=-1)(scores).numpy()
+            return weights
diff --git a/pytorch_grad_cam/sobel_cam.py b/pytorch_grad_cam/sobel_cam.py
new file mode 100644
index 0000000..84168a7
--- /dev/null
+++ b/pytorch_grad_cam/sobel_cam.py
@@ -0,0 +1,11 @@
+import cv2
+
+
+def sobel_cam(img):
+    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+    grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
+    grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
+    abs_grad_x = cv2.convertScaleAbs(grad_x)
+    abs_grad_y = cv2.convertScaleAbs(grad_y)
+    grad = cv2.addWeighted(abs_grad_x, 0.5, abs_grad_y, 0.5, 0)
+    return grad
diff --git a/pytorch_grad_cam/utils/__init__.py b/pytorch_grad_cam/utils/__init__.py
new file mode 100644
index 0000000..269a526
--- /dev/null
+++ b/pytorch_grad_cam/utils/__init__.py
@@ -0,0 +1,4 @@
+from pytorch_grad_cam.utils.image import deprocess_image
+from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
+from pytorch_grad_cam.utils import model_targets
+from pytorch_grad_cam.utils import reshape_transforms
diff --git a/pytorch_grad_cam/utils/__pycache__/__init__.cpython-37.pyc b/pytorch_grad_cam/utils/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..b2a360c
Binary files /dev/null and b/pytorch_grad_cam/utils/__pycache__/__init__.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/utils/__pycache__/__init__.cpython-38.pyc b/pytorch_grad_cam/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..e08b3f3
Binary files /dev/null and b/pytorch_grad_cam/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/utils/__pycache__/find_layers.cpython-37.pyc b/pytorch_grad_cam/utils/__pycache__/find_layers.cpython-37.pyc
new file mode 100644
index 0000000..83bec3d
Binary files /dev/null and b/pytorch_grad_cam/utils/__pycache__/find_layers.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/utils/__pycache__/find_layers.cpython-38.pyc b/pytorch_grad_cam/utils/__pycache__/find_layers.cpython-38.pyc
new file mode 100644
index 0000000..aef9ea6
Binary files /dev/null and b/pytorch_grad_cam/utils/__pycache__/find_layers.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/utils/__pycache__/image.cpython-37.pyc b/pytorch_grad_cam/utils/__pycache__/image.cpython-37.pyc
new file mode 100644
index 0000000..359400e
Binary files /dev/null and b/pytorch_grad_cam/utils/__pycache__/image.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/utils/__pycache__/image.cpython-38.pyc b/pytorch_grad_cam/utils/__pycache__/image.cpython-38.pyc
new file mode 100644
index 0000000..03b58c5
Binary files /dev/null and b/pytorch_grad_cam/utils/__pycache__/image.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/utils/__pycache__/model_targets.cpython-37.pyc b/pytorch_grad_cam/utils/__pycache__/model_targets.cpython-37.pyc
new file mode 100644
index 0000000..00a453e
Binary files /dev/null and b/pytorch_grad_cam/utils/__pycache__/model_targets.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/utils/__pycache__/model_targets.cpython-38.pyc b/pytorch_grad_cam/utils/__pycache__/model_targets.cpython-38.pyc
new file mode 100644
index 0000000..eb0ecd6
Binary files /dev/null and b/pytorch_grad_cam/utils/__pycache__/model_targets.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/utils/__pycache__/reshape_transforms.cpython-37.pyc b/pytorch_grad_cam/utils/__pycache__/reshape_transforms.cpython-37.pyc
new file mode 100644
index 0000000..ddbc899
Binary files /dev/null and b/pytorch_grad_cam/utils/__pycache__/reshape_transforms.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/utils/__pycache__/reshape_transforms.cpython-38.pyc b/pytorch_grad_cam/utils/__pycache__/reshape_transforms.cpython-38.pyc
new file mode 100644
index 0000000..24f0bed
Binary files /dev/null and b/pytorch_grad_cam/utils/__pycache__/reshape_transforms.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/utils/__pycache__/svd_on_activations.cpython-37.pyc b/pytorch_grad_cam/utils/__pycache__/svd_on_activations.cpython-37.pyc
new file mode 100644
index 0000000..379b63d
Binary files /dev/null and b/pytorch_grad_cam/utils/__pycache__/svd_on_activations.cpython-37.pyc differ
diff --git a/pytorch_grad_cam/utils/__pycache__/svd_on_activations.cpython-38.pyc b/pytorch_grad_cam/utils/__pycache__/svd_on_activations.cpython-38.pyc
new file mode 100644
index 0000000..66066cb
Binary files /dev/null and b/pytorch_grad_cam/utils/__pycache__/svd_on_activations.cpython-38.pyc differ
diff --git a/pytorch_grad_cam/utils/find_layers.py b/pytorch_grad_cam/utils/find_layers.py
new file mode 100644
index 0000000..4b9e445
--- /dev/null
+++ b/pytorch_grad_cam/utils/find_layers.py
@@ -0,0 +1,30 @@
+def replace_layer_recursive(model, old_layer, new_layer):
+    for name, layer in model._modules.items():
+        if layer == old_layer:
+            model._modules[name] = new_layer
+            return True
+        elif replace_layer_recursive(layer, old_layer, new_layer):
+            return True
+    return False
+
+
+def replace_all_layer_type_recursive(model, old_layer_type, new_layer):
+    for name, layer in model._modules.items():
+        if isinstance(layer, old_layer_type):
+            model._modules[name] = new_layer
+        replace_all_layer_type_recursive(layer, old_layer_type, new_layer)
+
+
+def find_layer_types_recursive(model, layer_types):
+    def predicate(layer):
+        return type(layer) in layer_types
+    return find_layer_predicate_recursive(model, predicate)
+
+
+def find_layer_predicate_recursive(model, predicate):
+    result = []
+    for name, layer in model._modules.items():
+        if predicate(layer):
+            result.append(layer)
+        result.extend(find_layer_predicate_recursive(layer, predicate))
+    return result
diff --git a/pytorch_grad_cam/utils/image.py b/pytorch_grad_cam/utils/image.py
new file mode 100644
index 0000000..34d92ba
--- /dev/null
+++ b/pytorch_grad_cam/utils/image.py
@@ -0,0 +1,183 @@
+import matplotlib
+from matplotlib import pyplot as plt
+from matplotlib.lines import Line2D
+import cv2
+import numpy as np
+import torch
+from torchvision.transforms import Compose, Normalize, ToTensor
+from typing import List, Dict
+import math
+
+
+def preprocess_image(
+        img: np.ndarray,
+        mean=[0.5, 0.5, 0.5],
+        std=[0.5, 0.5, 0.5]) -> torch.Tensor:
+    preprocessing = Compose([
+        ToTensor(),
+        Normalize(mean=mean, std=std)
+    ])
+    return preprocessing(img.copy()).unsqueeze(0)
+
+
+def deprocess_image(img):
+    """ see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65 """
+    img = img - np.mean(img)
+    img = img / (np.std(img) + 1e-5)
+    img = img * 0.1
+    img = img + 0.5
+    img = np.clip(img, 0, 1)
+    return np.uint8(img * 255)
+
+
+def show_cam_on_image(img: np.ndarray,
+                      mask: np.ndarray,
+                      use_rgb: bool = False,
+                      colormap: int = cv2.COLORMAP_JET,
+                      image_weight: float = 0.5) -> np.ndarray:
+    """ This function overlays the cam mask on the image as a heatmap.
+    By default the heatmap is in BGR format.
+
+    :param img: The base image in RGB or BGR format.
+    :param mask: The cam mask.
+    :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
+    :param colormap: The OpenCV colormap to be used.
+    :param image_weight: The final result is image_weight * img + (1-image_weight) * mask.
+    :returns: The image with the cam overlay.
+    """
+    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
+    if use_rgb:
+        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+    heatmap = np.float32(heatmap) / 255
+
+    if np.max(img) > 1:
+        raise Exception(
+            "The input image should be np.float32 in the range [0, 1]")
+
+    if image_weight < 0 or image_weight > 1:
+        raise Exception(
+            f"image_weight should be in the range [0, 1].\
+                Got: {image_weight}")
+
+    cam = (1 - image_weight) * heatmap + image_weight * img
+    cam = cam / np.max(cam)
+    return np.uint8(255 * cam)
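A short round trip through this helper, using random arrays as stand-ins for a real image and CAM:

    import cv2
    import numpy as np
    from pytorch_grad_cam.utils.image import show_cam_on_image

    # Stand-ins: a float RGB image in [0, 1] and a CAM mask in [0, 1].
    rgb_img = np.random.rand(224, 224, 3).astype(np.float32)
    grayscale_cam = np.random.rand(224, 224).astype(np.float32)

    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    # show_cam_on_image returns RGB here; cv2.imwrite expects BGR.
    cv2.imwrite('cam.jpg', cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR))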
+ """ + n_components = explanations.shape[0] + if colors is None: + # taken from https://github.com/edocollins/DFF/blob/master/utils.py + _cmap = plt.cm.get_cmap('gist_rainbow') + colors = [ + np.array( + _cmap(i)) for i in np.arange( + 0, + 1, + 1.0 / + n_components)] + concept_per_pixel = explanations.argmax(axis=0) + masks = [] + for i in range(n_components): + mask = np.zeros(shape=(img.shape[0], img.shape[1], 3)) + mask[:, :, :] = colors[i][:3] + explanation = explanations[i] + explanation[concept_per_pixel != i] = 0 + mask = np.uint8(mask * 255) + mask = cv2.cvtColor(mask, cv2.COLOR_RGB2HSV) + mask[:, :, 2] = np.uint8(255 * explanation) + mask = cv2.cvtColor(mask, cv2.COLOR_HSV2RGB) + mask = np.float32(mask) / 255 + masks.append(mask) + + mask = np.sum(np.float32(masks), axis=0) + result = img * image_weight + mask * (1 - image_weight) + result = np.uint8(result * 255) + + if concept_labels is not None: + px = 1 / plt.rcParams['figure.dpi'] # pixel in inches + fig = plt.figure(figsize=(result.shape[1] * px, result.shape[0] * px)) + plt.rcParams['legend.fontsize'] = int( + 14 * result.shape[0] / 256 / max(1, n_components / 6)) + lw = 5 * result.shape[0] / 256 + lines = [Line2D([0], [0], color=colors[i], lw=lw) + for i in range(n_components)] + plt.legend(lines, + concept_labels, + mode="expand", + fancybox=True, + shadow=True) + + plt.tight_layout(pad=0, w_pad=0, h_pad=0) + plt.axis('off') + fig.canvas.draw() + data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + plt.close(fig=fig) + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + data = cv2.resize(data, (result.shape[1], result.shape[0])) + result = np.hstack((result, data)) + return result + + +def scale_cam_image(cam, target_size=None): + result = [] + for img in cam: + img = img - np.min(img) + img = img / (1e-7 + np.max(img)) + if target_size is not None: + img = cv2.resize(img, target_size) + result.append(img) + result = np.float32(result) + + return result + + +def scale_accross_batch_and_channels(tensor, target_size): + batch_size, channel_size = tensor.shape[:2] + reshaped_tensor = tensor.reshape( + batch_size * channel_size, *tensor.shape[2:]) + result = scale_cam_image(reshaped_tensor, target_size) + result = result.reshape( + batch_size, + channel_size, + target_size[1], + target_size[0]) + return result diff --git a/pytorch_grad_cam/utils/model_targets.py b/pytorch_grad_cam/utils/model_targets.py new file mode 100644 index 0000000..489dd19 --- /dev/null +++ b/pytorch_grad_cam/utils/model_targets.py @@ -0,0 +1,103 @@ +import numpy as np +import torch +import torchvision + + +class ClassifierOutputTarget: + def __init__(self, category): + self.category = category + + def __call__(self, model_output): + if len(model_output.shape) == 1: + return model_output[self.category] + return model_output[:, self.category] + + +class ClassifierOutputSoftmaxTarget: + def __init__(self, category): + self.category = category + + def __call__(self, model_output): + if len(model_output.shape) == 1: + return torch.softmax(model_output, dim=-1)[self.category] + return torch.softmax(model_output, dim=-1)[:, self.category] + + +class BinaryClassifierOutputTarget: + def __init__(self, category): + self.category = category + + def __call__(self, model_output): + if self.category == 1: + sign = 1 + else: + sign = -1 + return model_output * sign + + +class SoftmaxOutputTarget: + def __init__(self): + pass + + def __call__(self, model_output): + return torch.softmax(model_output, dim=-1) + + +class 
+
+
+class RawScoresOutputTarget:
+    def __init__(self):
+        pass
+
+    def __call__(self, model_output):
+        return model_output
+
+
+class SemanticSegmentationTarget:
+    """ Takes a binary spatial mask and a category,
+        and returns the sum of the category scores
+        of the pixels inside the mask. """
+
+    def __init__(self, category, mask):
+        self.category = category
+        self.mask = torch.from_numpy(mask)
+        if torch.cuda.is_available():
+            self.mask = self.mask.cuda()
+
+    def __call__(self, model_output):
+        return (model_output[self.category, :, :] * self.mask).sum()
+
+
+class FasterRCNNBoxScoreTarget:
+    """ For every original detected bounding box specified in "bounding_boxes",
+        assign a score for how well the current bounding boxes match it,
+        1. in IoU, and
+        2. in the classification score.
+        If there is not a large enough overlap, or the category changed,
+        assign a score of 0.
+
+        The total score is the sum of all the box scores.
+    """
+
+    def __init__(self, labels, bounding_boxes, iou_threshold=0.5):
+        self.labels = labels
+        self.bounding_boxes = bounding_boxes
+        self.iou_threshold = iou_threshold
+
+    def __call__(self, model_outputs):
+        output = torch.Tensor([0])
+        if torch.cuda.is_available():
+            output = output.cuda()
+
+        if len(model_outputs["boxes"]) == 0:
+            return output
+
+        for box, label in zip(self.bounding_boxes, self.labels):
+            box = torch.Tensor(box[None, :])
+            if torch.cuda.is_available():
+                box = box.cuda()
+
+            ious = torchvision.ops.box_iou(box, model_outputs["boxes"])
+            index = ious.argmax()
+            if ious[0, index] > self.iou_threshold and model_outputs["labels"][index] == label:
+                score = ious[0, index] + model_outputs["scores"][index]
+                output = output + score
+        return output
diff --git a/pytorch_grad_cam/utils/reshape_transforms.py b/pytorch_grad_cam/utils/reshape_transforms.py
new file mode 100644
index 0000000..509f092
--- /dev/null
+++ b/pytorch_grad_cam/utils/reshape_transforms.py
@@ -0,0 +1,34 @@
+import torch
+
+
+def fasterrcnn_reshape_transform(x):
+    target_size = x['pool'].size()[-2:]
+    activations = []
+    for key, value in x.items():
+        activations.append(
+            torch.nn.functional.interpolate(
+                torch.abs(value),
+                target_size,
+                mode='bilinear'))
+    activations = torch.cat(activations, axis=1)
+    return activations
+
+
+def swinT_reshape_transform(tensor, height=7, width=7):
+    result = tensor.reshape(tensor.size(0),
+                            height, width, tensor.size(2))
+
+    # Bring the channels to the first dimension,
+    # like in CNNs.
+    result = result.transpose(2, 3).transpose(1, 2)
+    return result
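Transformer activations come out as (batch, tokens, channels), so a reshape transform is needed to give the CAM code the (batch, channels, H, W) layout it expects; such a function can be passed as the reshape_transform argument of any CAM constructor. A quick shape check on a random stand-in tensor:

    import torch
    from pytorch_grad_cam.utils.reshape_transforms import swinT_reshape_transform

    # A stand-in Swin block output: batch=1, 49 spatial tokens, 768 channels.
    tokens = torch.randn(1, 49, 768)
    spatial = swinT_reshape_transform(tokens, height=7, width=7)
    print(spatial.shape)  # torch.Size([1, 768, 7, 7]) -- CNN-style layout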
+
+
+def vit_reshape_transform(tensor, height=14, width=14):
+    result = tensor[:, 1:, :].reshape(tensor.size(0),
+                                      height, width, tensor.size(2))
+
+    # Bring the channels to the first dimension,
+    # like in CNNs.
+    result = result.transpose(2, 3).transpose(1, 2)
+    return result
diff --git a/pytorch_grad_cam/utils/svd_on_activations.py b/pytorch_grad_cam/utils/svd_on_activations.py
new file mode 100644
index 0000000..a406aee
--- /dev/null
+++ b/pytorch_grad_cam/utils/svd_on_activations.py
@@ -0,0 +1,19 @@
+import numpy as np
+
+
+def get_2d_projection(activation_batch):
+    # TBD: use pytorch batch svd implementation
+    activation_batch[np.isnan(activation_batch)] = 0
+    projections = []
+    for activations in activation_batch:
+        reshaped_activations = (activations).reshape(
+            activations.shape[0], -1).transpose()
+        # Centering before the SVD seems to be important here,
+        # otherwise the image returned is negative
+        reshaped_activations = reshaped_activations - \
+            reshaped_activations.mean(axis=0)
+        U, S, VT = np.linalg.svd(reshaped_activations, full_matrices=True)
+        projection = reshaped_activations @ VT[0, :]
+        projection = projection.reshape(activations.shape[1:])
+        projections.append(projection)
+    return np.float32(projections)
diff --git a/pytorch_grad_cam/xgrad_cam.py b/pytorch_grad_cam/xgrad_cam.py
new file mode 100644
index 0000000..81a920f
--- /dev/null
+++ b/pytorch_grad_cam/xgrad_cam.py
@@ -0,0 +1,31 @@
+import numpy as np
+from pytorch_grad_cam.base_cam import BaseCAM
+
+
+class XGradCAM(BaseCAM):
+    def __init__(self, model, target_layers, use_cuda=False,
+                 reshape_transform=None):
+        super(XGradCAM, self).__init__(model, target_layers, use_cuda,
+                                       reshape_transform)
+
+    def get_cam_weights(self,
+                        input_tensor,
+                        target_layer,
+                        target_category,
+                        activations,
+                        grads):
+        sum_activations = np.sum(activations, axis=(2, 3))
+        eps = 1e-7
+        weights = grads * activations / \
+            (sum_activations[:, :, None, None] + eps)
+        weights = weights.sum(axis=(2, 3))
+        return weights
diff --git a/restful_main.py b/restful_main.py
new file mode 100644
index 0000000..4d6fca2
--- /dev/null
+++ b/restful_main.py
@@ -0,0 +1,52 @@
+from flask import Flask, jsonify
+from flask_cors import CORS
+from flask_restful import Api, Resource, reqparse
+from utils import get_log
+from task.image_interpretability import ImageInterpretability
+import ast
+
+app = Flask(__name__, static_folder='', static_url_path='')
+app.config['UPLOAD_FOLDER'] = "./upload_data"
+cors = CORS(app, resources={r"*": {"origins": "*"}})
+api = Api(app)
+
+
+class Interpretability2Image(Resource):
+    # NOTE: LOG_PATH was referenced below but never defined in the original
+    # commit; this default path is an assumption.
+    LOG_PATH = './interpretability_image.log'
+
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('image_name', type=str, required=True, default='', help='')
+        parser.add_argument('image_path', type=str, required=True, default='', help='')
+        parser.add_argument('Interpretability_method', type=str, required=True, default='gradcam', help='')
+        parser.add_argument('model_info', type=str, required=True, default={}, help='')
+        parser.add_argument('output_path', type=str, required=True, default='', help='')
+        parser.add_argument('kwargs', type=str, required=True, default={}, help='')
+        args = parser.parse_args()
+        print(args)
+
+        # model_info and kwargs arrive as strings; parse them back into dicts
+        # (the original parsed a non-existent dataset_info argument here).
+        args.model_info = ast.literal_eval(args.model_info)
+        args.kwargs = ast.literal_eval(args.kwargs)
+
+        Interpretability = ImageInterpretability()
+        # Match ImageInterpretability.perform(image_path, method, model_info,
+        # output_path, **kwargs); the original passed keyword arguments that
+        # do not exist in that signature.
+        rst = Interpretability.perform(
+            image_path=args.image_path,
+            method=args.Interpretability_method,
+            model_info=args.model_info,
+            output_path=args.output_path,
+            **args.kwargs
+        )
+        return jsonify(rst)
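A hypothetical client call against this endpoint, assuming the service is running locally on port 5002 (as configured at the bottom of this file); the field values are illustrative only:

    import requests

    payload = {
        'image_name': 'img_2.png',
        'image_path': 'sample/img_2.png',
        'Interpretability_method': 'gradcam',
        'model_info': "{'model_name': 'resnet50', 'source': 'torchvision'}",
        'output_path': './output/',
        'kwargs': "{'aug_smooth': True, 'eigen_smooth': True}",
    }
    # model_info and kwargs are sent as strings and parsed server-side
    # with ast.literal_eval.
    resp = requests.post('http://127.0.0.1:5002/Interpretability2Image', data=payload)
    print(resp.json())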
+
+    def get(self):
+        msg = get_log(log_path=Interpretability2Image.LOG_PATH)
+        if msg:
+            return jsonify({'status': 1, 'log': msg})
+        else:
+            return jsonify({'status': 0, 'log': None})
+
+
+api.add_resource(Interpretability2Image, '/Interpretability2Image')
+
+if __name__ == '__main__':
+    app.run(host='0.0.0.0', port=5002, debug=True)
\ No newline at end of file
diff --git a/sample/ILSVRC2012_val_00000001.JPEG b/sample/ILSVRC2012_val_00000001.JPEG
new file mode 100644
index 0000000..fd3a93f
Binary files /dev/null and b/sample/ILSVRC2012_val_00000001.JPEG differ
diff --git a/sample/ILSVRC2012_val_00000002.JPEG b/sample/ILSVRC2012_val_00000002.JPEG
new file mode 100644
index 0000000..543f639
Binary files /dev/null and b/sample/ILSVRC2012_val_00000002.JPEG differ
diff --git a/sample/ILSVRC2012_val_00000003.JPEG b/sample/ILSVRC2012_val_00000003.JPEG
new file mode 100644
index 0000000..b32c5ed
Binary files /dev/null and b/sample/ILSVRC2012_val_00000003.JPEG differ
diff --git a/sample/ILSVRC2012_val_00000004.JPEG b/sample/ILSVRC2012_val_00000004.JPEG
new file mode 100644
index 0000000..182189a
Binary files /dev/null and b/sample/ILSVRC2012_val_00000004.JPEG differ
diff --git a/sample/ILSVRC2012_val_00000005.JPEG b/sample/ILSVRC2012_val_00000005.JPEG
new file mode 100644
index 0000000..a68ef17
Binary files /dev/null and b/sample/ILSVRC2012_val_00000005.JPEG differ
diff --git a/sample/ILSVRC2012_val_00000006.JPEG b/sample/ILSVRC2012_val_00000006.JPEG
new file mode 100644
index 0000000..f284522
Binary files /dev/null and b/sample/ILSVRC2012_val_00000006.JPEG differ
diff --git a/sample/ILSVRC2012_val_00000007.JPEG b/sample/ILSVRC2012_val_00000007.JPEG
new file mode 100644
index 0000000..be25a8c
Binary files /dev/null and b/sample/ILSVRC2012_val_00000007.JPEG differ
diff --git a/sample/ILSVRC2012_val_00000008.JPEG b/sample/ILSVRC2012_val_00000008.JPEG
new file mode 100644
index 0000000..6b8ba97
Binary files /dev/null and b/sample/ILSVRC2012_val_00000008.JPEG differ
diff --git a/sample/ILSVRC2012_val_00000009.JPEG b/sample/ILSVRC2012_val_00000009.JPEG
new file mode 100644
index 0000000..9f6c9c0
Binary files /dev/null and b/sample/ILSVRC2012_val_00000009.JPEG differ
diff --git a/sample/both.png b/sample/both.png
new file mode 100644
index 0000000..61bdcae
Binary files /dev/null and b/sample/both.png differ
diff --git a/sample/horses.jpg b/sample/horses.jpg
new file mode 100644
index 0000000..a82d6a6
Binary files /dev/null and b/sample/horses.jpg differ
diff --git a/sample/img.png b/sample/img.png
new file mode 100644
index 0000000..28db411
Binary files /dev/null and b/sample/img.png differ
diff --git a/sample/img_1.png b/sample/img_1.png
new file mode 100644
index 0000000..6e3eeb0
Binary files /dev/null and b/sample/img_1.png differ
diff --git a/sample/img_2.png b/sample/img_2.png
new file mode 100644
index 0000000..f15a71b
Binary files /dev/null and b/sample/img_2.png differ
diff --git a/task/__pycache__/image_interpretability.cpython-38.pyc b/task/__pycache__/image_interpretability.cpython-38.pyc
new file mode 100644
index 0000000..1be1e48
Binary files /dev/null and b/task/__pycache__/image_interpretability.cpython-38.pyc differ
diff --git a/task/image_interpretability.py b/task/image_interpretability.py
new file mode 100644
index 0000000..4e9b297
--- /dev/null
+++ b/task/image_interpretability.py
@@ -0,0 +1,238 @@
+import argparse
+import cv2
+import numpy as np
+import torch
+import json
+from torchvision import models
+from importlib import import_module
+from utils import LogHelper
+from pytorch_grad_cam import GradCAM, \
+    HiResCAM, \
+    ScoreCAM, \
+    GradCAMPlusPlus, \
+    AblationCAM, \
+    XGradCAM, \
+    EigenCAM, \
+    EigenGradCAM, \
+    LayerCAM, \
+    FullGrad, \
+    GradCAMElementWise
+from pytorch_grad_cam.utils.find_layers import find_layer_types_recursive
+from pytorch_grad_cam import GuidedBackpropReLUModel
+from pytorch_grad_cam.utils.image import show_cam_on_image, \
+    deprocess_image, \
+    preprocess_image
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+from PIL import Image
+from torchvision import transforms
+import urllib.request
+import json
+
+
+# Load labels
+with open("imagenet_1000.json") as f:
+    labels = json.load(f)
+
+
+available_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--use-cuda', action='store_true', default=True,
+                        help='Use NVIDIA GPU acceleration')
+    parser.add_argument('--aug_smooth', action='store_true',
+                        help='Apply test time augmentation to smooth the CAM')
+    parser.add_argument(
+        '--eigen_smooth',
+        action='store_true',
+        help='Reduce noise by taking the first principal component '
+             'of cam_weights*activations')
+    args = parser.parse_args()
+    args.use_cuda = args.use_cuda and torch.cuda.is_available()
+    if args.use_cuda:
+        print('Using GPU for acceleration')
+    else:
+        print('Using CPU for computation')
+
+    return args
+
+
+class ImageInterpretability():
+    def __init__(self):
+        super().__init__()
+
+    def perform(self, image_path: str, method: str, model_info: dict, output_path: str, log_path=None, **kwargs):
+        # aug_smooth, eigen_smooth
+        args = get_args()
+        '''
+        Path of the input image
+        :param image_path: (type=str, required=True) value= (path to the image)
+
+        Interpretability algorithm
+        :param method: (type=str, required=True) value=['gradcam', 'hirescam', 'scorecam', 'gradcam++',
+        'ablationcam', 'xgradcam', 'eigencam', 'eigengradcam', 'layercam', 'gradcamelementwise', 'fullgrad']
+
+        Model description
+        :param model_info: (type=dict, required=True), e.g. {'model_name': 'resnet50', 'source': 'torchvision'};
+        model_name is one of [resnet18, resnet50, vgg11, vgg13, vgg16, vgg19, densenet161, mnasnet1_0]
+
+        kwargs: optional parameters, passed only for specific needs
+        {
+            aug_smooth: apply test-time augmentation to improve the CAM quality
+                (type=bool) value=[True, False]
+
+            eigen_smooth: compute the matrix product between the CAM weights and the
+                activations, then take its first principal component to reduce noise
+                (type=bool) value=[True, False]
+        }
+        '''
+        methods = \
+            {"gradcam": GradCAM,
+             "hirescam": HiResCAM,
+             "scorecam": ScoreCAM,
+             "gradcam++": GradCAMPlusPlus,
+             "ablationcam": AblationCAM,
+             "xgradcam": XGradCAM,
+             "eigencam": EigenCAM,
+             "eigengradcam": EigenGradCAM,
+             "layercam": LayerCAM,
+             "fullgrad": FullGrad,
+             "gradcamelementwise": GradCAMElementWise}
+
+        # model = models.resnet50(pretrained=True)
+        # model_class = getattr(models, model_info.model_name)
+        # model = model_class(pretrained=True)
+        model = self.get_model(info=model_info, device=available_device)
+        # attack_log = LogHelper(log_path=log_path, root_log_name='aitest').build_new_log()
+
+        if 'resnet' in model_info.get('model_name').lower():
+            target_layers = [model.layer4]
+        elif 'vgg' in model_info.get('model_name').lower():
+            target_layers = [model.features[-1]]
+        elif 'densenet' in model_info.get('model_name').lower():
+            target_layers = [model.features[-1]]
+        elif 'mnasnet' in model_info.get('model_name').lower():
+            target_layers = [model.layers[-1]]
+        else:
+            target_layers = find_layer_types_recursive(model, [torch.nn.ReLU])
+
+        rgb_img = cv2.imread(image_path, 1)[:, :, ::-1]
+        rgb_img = np.float32(rgb_img) / 255
+        input_tensor = preprocess_image(rgb_img,
+                                        mean=[0.485, 0.456, 0.406],
+                                        std=[0.229, 0.224, 0.225])
+
+        targets = None
+
+        cam_algorithm = methods[method]
+        with cam_algorithm(model=model,
+                           target_layers=target_layers,
+                           use_cuda=args.use_cuda) as cam:
+
+            # AblationCAM and ScoreCAM have batched implementations.
+            # You can override the internal batch size for faster computation.
+            cam.batch_size = 32
+            # The docstring marks these as optional, so fall back to False
+            # instead of raising a KeyError when they are not passed.
+            aug_smooth = kwargs.get('aug_smooth', False)
+            eigen_smooth = kwargs.get('eigen_smooth', False)
+            grayscale_cam = cam(input_tensor=input_tensor,
+                                targets=targets,
+                                aug_smooth=aug_smooth,
+                                eigen_smooth=eigen_smooth)
+
+            # Here grayscale_cam has only one image in the batch
+            grayscale_cam = grayscale_cam[0, :]
+
+            cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
+
+            # cam_image is RGB encoded whereas "cv2.imwrite" requires BGR encoding.
+            cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)
+
+        gb_model = GuidedBackpropReLUModel(model=model, use_cuda=args.use_cuda)
+        gb = gb_model(input_tensor, target_category=None)
+
+        cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam])
+        cam_gb = deprocess_image(cam_mask * gb)
+        gb = deprocess_image(gb)
+
+        model.eval()
+        image = Image.open(image_path)
+        transform = transforms.Compose([
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=[0.485, 0.456, 0.406],
+                std=[0.229, 0.224, 0.225]
+            )
+        ])
+        image = transform(image)
+        # Perform forward pass
+        with torch.no_grad():
+            output = model(image.unsqueeze(0))
+
+        # Convert output to probabilities
+        probabilities = torch.softmax(output[0], dim=0)
+        # Get the predicted label, its probability and its name
+        predicted_class = torch.argmax(probabilities).item()
+        probability = probabilities[predicted_class].item()
+        label_name = labels[str(predicted_class)]
+        # print(f"Predicted label: {label_name}")
+
+        # cv2.imwrite cannot handle paths containing non-ASCII (e.g. Chinese) characters
+        cv2.imwrite(output_path + label_name + '-' + method + "_cam3.jpg", cam_image)
+        cv2.imwrite(output_path + label_name + '-' + method + "_gb.jpg", gb)
+        cv2.imwrite(output_path + label_name + '-' + method + "_cam_gb.jpg", cam_gb)
+
+        return {'output_path': output_path, 'Predicted_class': label_name, 'Probability': probability}
+
+    @staticmethod
+    def get_model(info, device):
+        # if isinstance(info, dict):
+        #     info = argparse.Namespace(**info)
+        #
+        # if isinstance(info.path, dict):
+        #     if isinstance(info.path, str):
+        #         info.path = json.loads(info.path)
+
+        # load pytorch model from user upload files
+        # if info.ownership != 'aitest' and ('upload' in info.type and 'pytorch' in info.type):
+        #     return load_pytorch_model(state_dict_path=info.path['parameter_file'], device=device, net_file_path=info.path['structure_file'])
+
+        # load a built-in image model from torchvision
+        if 'torchvision' == info.get('source'):
+            model_class = getattr(models, info.get('model_name'))
+            return model_class(pretrained=True)
+
+
+def load_pytorch_model(state_dict_path, model_class_name, device, net_file_path=None):
+    """
+    model = load_user_model('word_cnn_for_classification',
+                            'C:\\Users\\pcl\\.cache\\textattack\\models\\classification\\cnn\\rotten-tomatoes\\pytorch_model.bin',
+                            'WordCNNForClassification')
+    """
+    if net_file_path:
+        net_file_path = str(net_file_path).replace('\\', '/')
+        net_file_path = net_file_path.replace('./', '').replace('/', '.')
+        model_class = getattr(import_module(net_file_path), model_class_name)
+        model = model_class()
+        model.load(state_dict_path, device)
+        # tokenizer = model.tokenizer
+        # model = textattack.models.wrappers.PyTorchModelWrapper(
+        #     model, tokenizer
+        # )
+        return model
+    else:
+        return torch.load(state_dict_path, device)
diff --git a/test_image_interpretability.py b/test_image_interpretability.py
new file mode 100644
index 0000000..7a7ea7c
--- /dev/null
+++ b/test_image_interpretability.py
@@ -0,0 +1,27 @@
+from task.image_interpretability import ImageInterpretability
+
+
+# # TODO: Run an interpretability analysis with the gradcam algorithm on a resnet
+# # model, using the local image both.png and default parameters
+# files_path = ImageInterpretability().perform(image_path='sample/both.png',method='gradcam', model_name='resnet',output_path='D:\桌面\image_interprebility\image_interprebility\image_interprebility/test')
+# print(files_path)
+#
+kwargs = {
+    "aug_smooth": True,
+    "eigen_smooth": True
+}
+model_info = {"model_name": 'resnet50',
+              "source": "torchvision"
+              }
+# TODO: Run an interpretability analysis with the fullgrad algorithm on a resnet
+# model, using the local image img_2.png and the parameters above
+files_path = ImageInterpretability().perform(image_path='sample/img_2.png', method='fullgrad', model_info=model_info, output_path="D:/test_image/", **kwargs)
+print(files_path)
+
+
+# kwargs = {
+#     "target_layer": '',
+#     "aug_smooth": True,
+#     "eigen_smooth": True
+# }
+# # TODO: Run gradcam on a resnet model, applying test-time augmentation to improve
+# # CAM quality and principal-component extraction to reduce noise; vary
+# # target_layer to obtain results for different layers
+# files_path = ImageInterpretability().perform(image_path='sample/ILSVRC2012_val_00000002.JPEG',method='gradcam', model_name='resnet',**kwargs)
+# print(files_path)
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..7029047
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,88 @@
+import logging
+import os
+import torch
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# aitest data
+MODEL_SAVE_PATH = 'data/aitest/models'
+ADV_EXAMPLES_SAVE_PATH = 'data/aitest/adv_examples'
+DATASET_SAVE_PATH = 'data/aitest/datasets'
+FILES_SAVE_PATH = 'data/aitest/files'
+
+
+# text_attack configuration
+USE_ENCODING_PATH = "data/aitest/models/tfhub_modules/063d866c06683311b44b4992fd46003be952409c"
+TEXTATTACK_CACHE_DIR = 'data/aitest/models/textattack'
+
+# text robust analysis
+GPT2_PATH = 'data/aitest/models/gpt-2'
+NLTK_PATH = 'data/aitest/files/nltk_data'
+
+
+class LogHelper:
+
+    LOG_STRING = "\033[34;1mAItest\033[0m"
+
+    def __init__(self, log_path, root_log_name=None):
+        self.log_path = log_path
+
+        if root_log_name:
+            self.log_name = f"{root_log_name}.Interpretability_Image"
+            self.logger = logging.getLogger(f"{root_log_name}.Interpretability_Image")
+        else:
+            self.log_name = 'aitest_Interpretability_Image'
+            self.logger = logging.getLogger('aitest_Interpretability_Image')
+            formatter = logging.Formatter(f"{LogHelper.LOG_STRING}: %(message)s")
+            # logging.StreamHandler expects a stream object; the original
+            # passed an empty string, which breaks on the first emit.
+            stream_handler = logging.StreamHandler()
+            stream_handler.setFormatter(formatter)
+            self.logger.addHandler(stream_handler)
+
+        log_formatter = logging.Formatter("%(asctime)s : %(message)s")
+        log_filter = logging.Filter(name=self.log_name)
+        log_handler = logging.FileHandler(self.log_path)
+        log_handler.setLevel(level=logging.INFO)
+        log_handler.setFormatter(log_formatter)
+        log_handler.addFilter(log_filter)
+        self.fh = log_handler
+
+    def build_new_log(self):
+        self.logger.addHandler(self.fh)
+        return self
+
+    def insert_log(self, content=None):
+        if content:
+            self.logger.info(content)
+
+    def finish_log(self):
+        self.logger.removeHandler(self.fh)
+
+
+def get_log(log_path):
+    if os.path.exists(log_path):
+        with open(log_path, 'r+', encoding='utf-8') as f:
+            texts = f.read()
+            f.truncate(0)
+        return texts
+    return None
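A small sketch of the intended logging round trip, assuming a writable log path; note the logger level must be lowered to INFO for insert_log to emit anything:

    import logging
    from utils import LogHelper, get_log

    log = LogHelper(log_path='./aitest.log').build_new_log()
    log.logger.setLevel(logging.INFO)   # the file handler filters at INFO
    log.insert_log('interpretability run started')
    log.finish_log()

    print(get_log('./aitest.log'))      # returns the text and truncates the file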
+
+
+def get_size(path, show_bytes=True):
+    size = 0
+    if os.path.isfile(path):
+        size = os.path.getsize(path)
+    elif os.path.isdir(path):
+        for root, dirs, files in os.walk(path):
+            for file in files:
+                size += os.path.getsize(os.path.join(root, file))
+    if show_bytes:
+        return int(size)
+
+    for x in ['bytes', 'KB', 'MB', 'GB', 'TB']:
+        if size < 1024.0:
+            return "%3.1f %s" % (size, x)
+        size /= 1024.0
+    return size
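For completeness, a quick sketch of the size helper, assuming the sample directory from this commit exists on disk:

    from utils import get_size

    print(get_size('sample', show_bytes=True))    # total size in bytes, as an int
    print(get_size('sample', show_bytes=False))   # human-readable, e.g. '1.2 MB'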