본문 바로가기
DL

[Grad Cam] gradcam pytorch 구현 with chatGPT

by apsdfjoi 2023. 6. 14.
728x90
반응형

chatGPT를 사용하여 Grad Cam을 구현해 봤다. 물론 한 번에 구현되진 않는다.

마지막 컨볼루션 레이어(layer4)의 출력을 사용한 결과

Grad Cam은 마지막 컨볼루션 레이어를 사용해야 의미있는 결과를 얻을 수 있다.

처음 컨볼루션 레이어(layer1)의 출력을 사용한 결과

코드는 다음과 같다.

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import PIL.Image as Image
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

class GradCam:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradient = None
        self.target_activations = None
       
        self.model.eval()
        self.register_hooks()
       
    def register_hooks(self):
        def forward_hook(module, input, output):
            self.target_activations = output.detach()
           
        def backward_hook(module, grad_input, grad_output):
            self.gradient = grad_output[0]
       
        target_layer = self.get_target_layer()
        target_layer.register_forward_hook(forward_hook)
        target_layer.register_full_backward_hook(backward_hook)
       
    def get_target_layer(self):
        if isinstance(self.target_layer, str):
            modules = dict([*self.model.named_modules()])
            print('layer : str')
            return modules[self.target_layer]
        elif isinstance(self.target_layer, int):
            children = [*self.model.children()]
            return children[self.target_layer]
        elif isinstance(self.target_layer, nn.Sequential):
            print('layer : nn.sequentail')
            return self.target_layer[-1]
        else:
            raise ValueError("Invalid target layer value. It should be either the layer name (str), index (int), or nn.Sequential.")
       
    def forward_pass(self, input_tensor):
        return self.model(input_tensor)
   
    def backward_pass(self, output):
        one_hot = torch.zeros_like(output, dtype=torch.float)
        one_hot[0][torch.argmax(output)] = 1.0
       
        output.backward(gradient=one_hot)
        print(f'debug : class :{[torch.argmax(output)]}')
       
    def generate_gradcam(self, input_tensor):
        self.model.zero_grad()
        output = self.forward_pass(input_tensor)
        self.backward_pass(output)

        gradients = self.gradient.cpu().numpy()
        target_activations = self.target_activations.cpu().numpy()

        weights = np.mean(gradients, axis=(2, 3), keepdims=True)
        gradcam = np.sum(weights * target_activations, axis=1)
        gradcam = np.maximum(gradcam, 0)
        # gradcam = gradcam[:,:,0]
        print(gradcam.shape)
        gradcam = np.squeeze(gradcam,0)
        gradcam = cv2.resize(gradcam, (224, 224))
        gradcam = (gradcam - np.min(gradcam)) / (np.max(gradcam) - np.min(gradcam) + 1e-7)

        return gradcam

def visualize_gradcam(image_path, gradcam):
    image = np.array(Image.open(image_path))
    image = cv2.resize(image,(224,224))
       
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title('Original Image')
    plt.axis('off')
   
    plt.subplot(1, 2, 2)
    plt.imshow(gradcam, cmap='jet', alpha=0.7)
    plt.imshow(image, alpha=0.5)
    plt.title('Grad-CAM')
    plt.axis('off')
   
    # plt.show()
    plt.savefig('test.jpg')
    print(image.shape,gradcam.shape)
   

if __name__ == "__main__":
    from torchvision.models import resnet50

    def load_model():
        model = resnet50(pretrained=True)
        model.eval()
        return model
   
    def get_transform():
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        return transform
   
    def process_image(image_path):
        image = Image.open(image_path).convert("RGB")
        image_tensor = get_transform()(image).unsqueeze(0)
        return image_tensor
   
    def visualize_gradcam_test(image_path, gradcam):
        image = np.array(Image.open(image_path))
        image = cv2.resize(image,(224,224))
           
        plt.subplot(1, 2, 1)
        plt.imshow(image)
        plt.title('Original Image')
        plt.axis('off')
       
        plt.subplot(1, 2, 2)
        plt.imshow(gradcam, cmap='jet', alpha=0.7)
        plt.imshow(image, alpha=0.5)
        plt.title('Grad-CAM')
        plt.axis('off')
       
        # plt.show()
        plt.savefig('test.jpg')
        print(image.shape,gradcam.shape)
   
   
    model = load_model()
    for name, module in model.named_children():
        print(name)
    pipeline = GradCam(model,'layer4')
    image_path = r"example.jpg"  # 이미지 경로를 적절히 지정
   
    image_tensor = process_image(image_path)
    gradcam = pipeline.generate_gradcam(image_tensor)
   
    visualize_gradcam_test(image_path, gradcam)

여기서 pytorch의 register hook 기능을 사용하였는데 이는 모델(모듈)의 forward나 backward 동작 중에 트리거를 설정할 수 있는 기능이다. forward 동작에서 해당 모듈의 출력과 backward 동작에서 해당 모듈에서 output에 대한 gradient를 저장하도록 설정돼있음을 확인할 수 있다.

ImageNet으로 pre-train된 ResNet50을 사용했으며 터미널 출력은 다음과 같다.

imagenet_class_index.json파일에서 class를 찾아보면 pug로 옳은 추론 결과를 확인할 수 있다. 

json 파일은 아래 링크의 두 번째 답변에서 찾았다.

 

Imagenet classes

I found a map of id -> label https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a . So for example with open("imagenet1000_clsid_to_human.txt") as f: idx2label = eval(f.read()) for idx in out[0].sort()[1][-10:]: print(idx2label(idx)) will work, though eval

discuss.pytorch.org

 

728x90
반응형

댓글