본문 바로가기
DL

[ViT] Vision Transformer 구현 -0 Linear Projection of Flattened Pathes

by apsdfjoi 2023. 4. 21.
728x90
반응형

Vision Transformer를 모듈별로 구현하며 궁금한 점을 기록했다.

출처 : https://arxiv.org/abs/2010.11929

Vision Transformer(ViT)의 데이터(이미지) 처리 프로세스를 ChatGPT에게 물어봤다.

  1. 입력 이미지에서 패치 추출: 입력 이미지에서 패치를 추출합니다. 일반적으로는 16x16 크기의 패치를 사용합니다. 이는 이미지의 크기에 따라 조정될 수 있습니다.
  2. 임베딩: 추출한 각 패치는 먼저 임베딩(embedding)되어 벡터로 변환됩니다. 이를 위해 일반적으로 선형 변환(linear transformation)과 사전 학습된(Pre-trained) 모델을 사용합니다.
  3. 위치 임베딩: 임베딩된 패치의 위치 정보를 제공하기 위해 위치 임베딩(Positional embedding)이 추가됩니다. 이는 각 패치의 위치에 대한 고유한 값을 포함하는 벡터입니다.
  4. 인코더(Encoder) 블록: 인코더 블록은 여러 개의 셀프 어텐션(self-attention) 및 피드 포워드(feedforward) 레이어를 사용하여 패치를 처리합니다. 이러한 레이어는 입력 패치에 대한 내부 상호 작용과 특성 강화를 통해 패턴을 추출하고, 인코더 출력을 생성합니다.
  5. 출력 계층: 인코더의 출력은 분류(classification)를 위한 완전 연결(fully-connected) 또는 소프트맥스(softmax) 레이어로 전달됩니다.

ViT는 기존 기계 번역 모델인 Transformer 아키텍처를 그대로 사용하며, text가 아닌 image에서 embedding vector를 구하기 위해 패치로 나누고 'Linear Projection of Flattened Pathes' 작업을 거친다.

C : 이미지 channel ,  N : patch 개수, P : patch 크기(가로,세로)

이때, Projection 연산은 이미지를 일정한 크기로 나눈 patch들을 바로 Flatten 할 수도 있지만  Convolution 연산으로 진행한다.

chatGPT를 활용하여 작성한 코드는 다음과 같다.

class ImageEmbedding(nn.Module):
    def __init__(selfimage_sizepatch_sizeemb_size):
        super().__init__()

        self.patch_size = patch_size

        # 이미지 데이터를 패치(patch)로 분할
        self.projection = nn.Conv2d(in_channels=3, out_channels=emb_size, kernel_size=patch_size, stride=patch_size)

    def forward(selfx):
        x = self.projection(x)
        # 패치의 형태를 변환
        # (batch_size, emb_size, num_patches, num_patches)
        x = x.flatten(start_dim=2)
        x = x.transpose(12)

        return x

torchvision.models.VisionTransformer 클래스에서 구현된 코드는 다음과 같다.

else:
            self.conv_proj = nn.Conv2d(
                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
            )

 

def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
        torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

Convolution을 사용하지 않고 패치를 나누는 코드는 다음과 같다.

self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

위 코드의 출처는 https://github.com/lucidrains/vit-pytorch/blob/e1b08c15b9b237329d30324ce40579d4d4afc761/vit_pytorch/vit.py#L94 이다.

einops의 Rearrange를 사용하여 가독성 있게 텐서를 다루고 있다. 문득 텐서를 조작할 때, 연산 결과가 어떻게 다른지 확인해보고 싶어졌다.

def extract_patches(imgpatch_sizemethod = 0):
    # img: 입력 이미지 (torch.Tensor)
    # patch_size: 패치 크기 (int)

    print(img.shape)
    patches = F.unfold(img, kernel_size=patch_size, stride=patch_size)
    print(patches.shape)
    if method == 0:
      patches = patches.reshape(img.shape[0], img.shape[2]*img.shape[3]//patch_size**2, img.shape[1]*patch_size**2)

    elif method == 1:
      patches = nn.Sequential(
                  Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size))(img)
    
    elif method == 2:
      patches = patches.permute(0,2,1)
    
    elif method == 3:
      patches = nn.Flatten(start_dim=2)(patches)
      patches = patches.transpose(1,2)

    print(patches.shape)

    return patches

위 코드에서 input shape가 (B, C, H, W) 일 때, 2, 3번은 같은 출력을 보였다.

patches = nn.Sequential(
                  Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size))(img)

이 코드를 재현하고 싶은데 아직 방법을 찾지 못했다.

 

 

 

 

 

 

 

728x90
반응형

댓글