github.com/FrancescoSaverioZuppichini/ViT
위 코드를 참고하여 리뷰했습니다.
개요
패치임베딩까지 진행하였고 이번에는 Multi Head Attention을 진행해보도록 하겠습니다.
MHA(Multi Head Attention)
MHA는 위 그림과 같이 진행됩니다. VIT에서의 MHA는 QKV가 같은 텐서로 입력됩니다. 입력텐서는 3개의 Linear Projection을 통해 임베딩된 후 여러 개의 Head로 나눠진 후 각각 Scaled Dot-Product Attention을 진행합니다.
Linear Projection
emb_size = 768
num_heads = 8
keys = nn.Linear(emb_size, emb_size)
queries = nn.Linear(emb_size, emb_size)
values = nn.Linear(emb_size, emb_size)
print(keys, queries, values)
Linear(in_features=768, out_features=768, bias=True) Linear(in_features=768, out_features=768, bias=True) Linear(in_features=768, out_features=768, bias=True)
먼저 이전 글에서 임베딩된 입력텐서를 받아서 다시 임베딩사이즈로 Linear Projection을 하는 레이어를 3개 만듭니다. 입력 텐서를 QKV로 만드는 각 레이어는 모델 훈련과정에서 학습됩니다.
Multi-Head
queries = rearrange(queries(x), "b n (h d) -> b h n d", h=num_heads)
keys = rearrange(keys(x), "b n (h d) -> b h n d", h=num_heads)
values = rearrange(values(x), "b n (h d) -> b h n d", h=num_heads)
print('shape :', queries.shape, keys.shape, values.shape)
shape : torch.Size([8, 8, 197, 96]) torch.Size([8, 8, 197, 96]) torch.Size([8, 8, 197, 96])
이후 각 Linear Projection을 거친 QKV를 rearrange를 통해 8개의 Multi-Head로 나눠주게 됩니다.
Scaled Dot Product Attention
scaled dot-product attention을 구현한 코드입니다.
# Queries * Keys
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
print('energy :', energy.shape)
# Get Attention Score
scaling = emb_size ** (1/2)
att = F.softmax(energy, dim=-1) / scaling
print('att :', att.shape)
# Attention Score * values
out = torch.einsum('bhal, bhlv -> bhav ', att, values)
print('out :', out.shape)
# Rearrage to emb_size
out = rearrange(out, "b h n d -> b n (h d)")
print('out2 : ', out.shape)
energy : torch.Size([8, 8, 197, 197])
att : torch.Size([8, 8, 197, 197])
out : torch.Size([8, 8, 197, 96])
out2 : torch.Size([8, 197, 768])
위 그림과 같이 Q와 K를 곱합니다. einops를 이용해 자동으로 transpose 후 내적이 진행됩니다. 그다음 scaling 해준 후 얻어진 Attention Score와 V를 내적하고 다시 emb_size로 rearrange 하면 MHA의 output이 나오게 됩니다.
클래스로 구현된 코드
class MultiHeadAttention(nn.Module):
def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
super().__init__()
self.emb_size = emb_size
self.num_heads = num_heads
# fuse the queries, keys and values in one matrix
self.qkv = nn.Linear(emb_size, emb_size * 3)
self.att_drop = nn.Dropout(dropout)
self.projection = nn.Linear(emb_size, emb_size)
def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
# split keys, queries and values in num_heads
qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
queries, keys, values = qkv[0], qkv[1], qkv[2]
# sum up over the last axis
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len
if mask is not None:
fill_value = torch.finfo(torch.float32).min
energy.mask_fill(~mask, fill_value)
scaling = self.emb_size ** (1/2)
att = F.softmax(energy, dim=-1) / scaling
att = self.att_drop(att)
# sum up over the third axis
out = torch.einsum('bhal, bhlv -> bhav ', att, values)
out = rearrange(out, "b h n d -> b n (h d)")
out = self.projection(out)
return out
최종코드에서는 QKV 당 각각 1개씩의 Linear Layer를 적용한 것을 텐서 연산을 한번에 하기 위해 Linear Layer를 emb_size*3으로 설정한 후 연산시 QKV를 각각 나눠주게 됩니다. 또한 Attention 시 무시할 정보를 설정하기 위한 masking 코드도 추가되었습니다. 마지막으로 나오는 out은 최종적으로 한번의 Linear Layer를 거쳐서 나오게 되는게 MHA의 모든 구현입니다.
결론
VIT의 핵심인 MHA를 구현하였고 다음 글은 남은 부분인 Resiudal, MLP, 레이어 적층 등을 리뷰하도록 하겠습니다.
'AI' 카테고리의 다른 글
TransUNet - Transformer를 적용한 Segmentation Model 논문 리뷰 (0) | 2021.02.25 |
---|---|
Vision Transfromer (ViT) Pytorch 구현 코드 리뷰 - 3 (2) | 2021.02.22 |
Vision Transfromer (ViT) Pytorch 구현 코드 리뷰 - 1 (3) | 2021.02.19 |
Semantic Segmentation information Links (0) | 2021.02.19 |
Resnet 18-layer pytorch 코드 리뷰 (0) | 2021.02.16 |