AI

Vision Transfromer (ViT) Pytorch 구현 코드 리뷰 - 2

kimy 2021. 2. 22. 11:26

github.com/FrancescoSaverioZuppichini/ViT

 

FrancescoSaverioZuppichini/ViT

Implementing Vi(sion)T(transformer). Contribute to FrancescoSaverioZuppichini/ViT development by creating an account on GitHub.

github.com

위 코드를 참고하여 리뷰했습니다.

 

개요

패치임베딩까지 진행하였고 이번에는 Multi Head Attention을 진행해보도록 하겠습니다.

 

MHA(Multi Head Attention)

Multi Head Attention

MHA는 위 그림과 같이 진행됩니다. VIT에서의 MHA는 QKV가 같은 텐서로 입력됩니다. 입력텐서는 3개의 Linear Projection을 통해 임베딩된 후 여러 개의 Head로 나눠진 후 각각 Scaled Dot-Product Attention을 진행합니다.

 

Linear Projection

emb_size = 768
num_heads = 8

keys = nn.Linear(emb_size, emb_size)
queries = nn.Linear(emb_size, emb_size)
values = nn.Linear(emb_size, emb_size)
print(keys, queries, values)
Linear(in_features=768, out_features=768, bias=True) Linear(in_features=768, out_features=768, bias=True) Linear(in_features=768, out_features=768, bias=True)

먼저 이전 글에서 임베딩된 입력텐서를 받아서 다시 임베딩사이즈로 Linear Projection을 하는 레이어를 3개 만듭니다. 입력 텐서를 QKV로 만드는 각 레이어는 모델 훈련과정에서 학습됩니다.

 

Multi-Head

queries = rearrange(queries(x), "b n (h d) -> b h n d", h=num_heads)
keys = rearrange(keys(x), "b n (h d) -> b h n d", h=num_heads)
values  = rearrange(values(x), "b n (h d) -> b h n d", h=num_heads)

print('shape :', queries.shape, keys.shape, values.shape)
shape : torch.Size([8, 8, 197, 96]) torch.Size([8, 8, 197, 96]) torch.Size([8, 8, 197, 96])

이후 각 Linear Projection을 거친 QKV를 rearrange를 통해 8개의 Multi-Head로 나눠주게 됩니다.

 

Scaled Dot Product Attention

scaled dot-product attention을 구현한 코드입니다.

# Queries * Keys
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
print('energy :', energy.shape)

# Get Attention Score
scaling = emb_size ** (1/2)
att = F.softmax(energy, dim=-1) / scaling
print('att :', att.shape)

# Attention Score * values
out = torch.einsum('bhal, bhlv -> bhav ', att, values)
print('out :', out.shape)

# Rearrage to emb_size
out = rearrange(out, "b h n d -> b n (h d)")
print('out2 : ', out.shape)
energy : torch.Size([8, 8, 197, 197])
att : torch.Size([8, 8, 197, 197])
out : torch.Size([8, 8, 197, 96])
out2 :  torch.Size([8, 197, 768])

위 그림과 같이 Q와 K를 곱합니다. einops를 이용해 자동으로 transpose 후 내적이 진행됩니다. 그다음 scaling 해준 후 얻어진 Attention Score와 V를 내적하고 다시 emb_size로 rearrange 하면 MHA의 output이 나오게 됩니다.

 

클래스로 구현된 코드

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        
    def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
        # split keys, queries and values in num_heads
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # sum up over the last axis
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)
            
        scaling = self.emb_size ** (1/2)
        att = F.softmax(energy, dim=-1) / scaling
        att = self.att_drop(att)
        # sum up over the third axis
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out

최종코드에서는 QKV 당 각각 1개씩의 Linear Layer를 적용한 것을 텐서 연산을 한번에 하기 위해 Linear Layer를 emb_size*3으로 설정한 후 연산시 QKV를 각각 나눠주게 됩니다. 또한 Attention 시 무시할 정보를 설정하기 위한 masking 코드도 추가되었습니다. 마지막으로 나오는 out은 최종적으로 한번의 Linear Layer를 거쳐서 나오게 되는게 MHA의 모든 구현입니다.

 

결론

VIT의 핵심인 MHA를 구현하였고 다음 글은 남은 부분인 Resiudal, MLP, 레이어 적층 등을 리뷰하도록 하겠습니다.