AI

Vision Transfromer (ViT) Pytorch 구현 코드 리뷰 - 1

kimy 2021. 2. 19. 16:21

개요

이미지를 패치들로 나누어 Transformer Encoder에 적용한 Vision Transformer의 구현코드 리뷰입니다.

github.com/FrancescoSaverioZuppichini/ViT

 

FrancescoSaverioZuppichini/ViT

Implementing Vi(sion)T(transformer). Contribute to FrancescoSaverioZuppichini/ViT development by creating an account on GitHub.

github.com

위 원본링크에 쉽고 자세하게 구현되어 있으나 공부목적으로 작성된 코드 리뷰입니다. 위 코드저자는 Einstein Notation 라이브러리들을 사용하여 각종 텐서계산을 구현하고 있습니다. 이러한 Einstein Notation은 한번 이해한다면 코드를 매우 직관적으로 파악할 수 있습니다.

 

ViT Architecture

ViT Architecture

ViT의 구조입니다. 원래의 Transformer에서 Encoder 부분만 사용하는 구조입니다. 입력 이미지는 PxP 사이즈의 패치들로 나눠지고 flatten 하게되면 1차원의 벡터형태로 Encoder에 들어가게 됩니다. Embedding 된 Word Vector와 같은 개념입니다. 그 후 MHA(Multi Head Attention)을 거친 후 Feed Forward Layer를 거치게 되는게 인코더의 한 블럭입니다.

Patch Embedding

이미지를 패치들로 나누어 Embedding 시켜주는 부분의 코드입니다. Embedding 시 class token과 positional embedding을 추가하게 됩니다.

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

필요한 라이브러리와 프레임워크입니다.

x = torch.randn(8, 3, 224, 224)
x.shape
torch.Size([8, 3, 224, 224])

배치사이즈 8, 채널 3, h, w = (224, 224)를 갖는 랜덤텐서를 사용하여 텐서연산의 과정을 살펴보도록 하겠습니다. 먼저 BATCHxCxH×W 형태를 가진 이미지를 BATCHxNx(P*P*C)의 벡터로 임베딩을 해주어야 합니다. P는 패치사이즈이며 N은 패치의 개수(H*W / (P*P))입니다.

patch_size = 16 # 16 pixels

print('x :', x.shape)
patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)
print('patches :', patches.shape)
x : torch.Size([8, 3, 224, 224])
patches : torch.Size([8, 196, 768])

einops의 rearrange를 이용하면 BATCHxCxHxW를 가진 텐서를 BATCHxNxPSIZE로 바꿔줄 수 있습니다. 여기서의 einstein operation을 보면 "8x3x(14*16)x(14*16) -> 8x(14*14)x(16*16*3)" 형태로 바뀌는 것을 확인할 수 있습니다. einops를 이용하면 이미지를 패치로 나누고 flatten하는 과정을 한번에 완성할 수 있습니다.

 

하지만 실제의 ViT에서는 einops같은 Linear Embedding이 아니라 kernal size와 stride size를 patch size로 갖는 Convolutional 2D Layer를 이용한 후 flatten 시켜줍니다. 이렇게 하면 performance gain이 있다고 저자는 말합니다.

patch_size = 16
in_channels = 3
emb_size = 768

projection = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

projection(x).shape
torch.Size([8, 196, 768])

output은 아까의 텐서와 같은 사이즈입니다.

 

다음은 Cls Token과 Positional Encoding을 추가하는 코드입니다.

emb_size = 768
img_size = 224
patch_size = 16

# 이미지를 패치사이즈로 나누고 flatten
projected_x = projection(x)
print('Projected X shape :', projected_x.shape)

# cls_token과 pos encoding Parameter 정의
cls_token = nn.Parameter(torch.randn(1,1, emb_size))
positions = nn.Parameter(torch.randn((img_size // patch_size) **2 + 1, emb_size))
print('Cls Shape :', cls_token.shape, ', Pos Shape :', positions.shape)

# cls_token을 반복하여 배치사이즈의 크기와 맞춰줌
batch_size = 8
cls_tokens = repeat(cls_token, '() n e -> b n e', b=batch_size)
print('Repeated Cls shape :', cls_tokens.shape)

# cls_token과 projected_x를 concatenate
cat_x = torch.cat([cls_tokens, projected_x], dim=1)

# position encoding을 더해줌
cat_x += positions
print('output : ', cat_x.shape)
Projected X shape : torch.Size([8, 196, 768])
Cls Shape : torch.Size([1, 1, 768]) , Pos Shape : torch.Size([197, 768])
Repeated Cls shape : torch.Size([8, 1, 768])
output :  torch.Size([8, 197, 768])

x를 패치로 나누고 flatten시키면 (8, 196, 768)의 텐서가 됩니다. 그 후 cls_token을 맨 앞에 붙여서 (8, 197, 768)의 텐서로 만들기 위해 cls_token을 (1, 1, 768)의 파라미터로 생성해줍니다. 생성된 파라미터는 einops의 repeat을 이용하여 (8, 1, 768) 사이즈로 확장됩니다. 이후 torch.cat을 통하여 dim=1인 차원에 concatenate시켜줍니다.

Position Encoding은 cls_token으로 늘어난 크기에 맞춰 1을 더한 (197, 768) 사이즈로 생성후 마지막에 더해주면 됩니다. 브로드캐스팅된 + 연산은 모든 배치에 같은 Pos Encoding을 더하게 됩니다.

 

하나의 클래스로 구현한 코드입니다.

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1,1, emb_size))
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) **2 + 1, emb_size))
        
    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        # prepend the cls token to the input
        x = torch.cat([cls_tokens, x], dim=1)
        # add position embedding
        x += self.positions

        return x

 

결론

이미지를 입력받아 패치사이즈로 나누고 1차원 벡터로 projection 시킨 후 cls token과 positional encoding하는 부분까지 진행하였습니다. 다음 글은 Multi Head Attention 부분입니다.