github.com/FrancescoSaverioZuppichini/ViT
This review is based on the code in the repository above.
Overview
In the previous posts we covered patch embedding and multi-head attention; this post implements the ViT encoder structure.
Residual Block
import torch
import torch.nn as nn

class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x                   # keep the input for the skip connection
        x = self.fn(x, **kwargs)  # run the wrapped sub-block
        x += res                  # add the residual back
        return x
This class implements a residual connection: it stores a sub-module fn, runs fn's forward pass on the input, and adds the original input back to the result.
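As a quick sanity check (my own sketch, not part of the original post), ResidualAdd can wrap any module whose output shape matches its input; here it wraps a hypothetical nn.Linear(768, 768):

# Illustrative usage of ResidualAdd with a hypothetical shape-preserving module
block = ResidualAdd(nn.Linear(768, 768))
x = torch.randn(8, 197, 768)
print(block(x).shape)   # torch.Size([8, 197, 768]) -- same shape, with the input added back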
Feed Forward MLP
class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )
Next comes the MLP applied after multi-head attention. It runs Linear - GELU - Dropout - Linear: the first Linear expands the embedding size by the expansion factor, and after GELU and Dropout the second Linear projects it back down to the original emb_size.
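A small shape check (my own addition, assuming the default emb_size=768 and expansion=4) makes the expand-then-project behaviour visible:

# Shape sketch for the default settings (emb_size=768, expansion=4)
ff = FeedForwardBlock(emb_size=768)
tokens = torch.randn(8, 197, 768)
print(ff(tokens).shape)   # torch.Size([8, 197, 768])
# Internally: 768 -> 3072 (Linear, GELU, Dropout) -> 768 (Linear)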
Transformer Encoder Block
class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 **kwargs):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            ))
        )
Putting everything implemented so far together gives a single Transformer encoder block. It follows a pre-norm layout: LayerNorm comes before the attention and MLP sub-blocks, and each sub-block is wrapped in ResidualAdd. Since the class inherits from nn.Sequential rather than nn.Module, no forward method needs to be redefined.
x = torch.randn(8, 3, 224, 224)
patches_embedded = PatchEmbedding()(x)
TransformerEncoderBlock()(patches_embedded).shape
torch.Size([8, 197, 768])
The result of feeding a test tensor through the block: batch size 8, 197 tokens (196 patches plus the CLS token), embedding size 768.
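For comparison, here is a rough sketch (my own, not in the reviewed repository) of what the same block would look like as an nn.Module with an explicit forward; inheriting from nn.Sequential simply removes this boilerplate:

# Equivalent nn.Module version (sketch) with an explicit forward
class TransformerEncoderBlockModule(nn.Module):
    def __init__(self, emb_size: int = 768, drop_p: float = 0., **kwargs):
        super().__init__()
        self.attn = ResidualAdd(nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, **kwargs),
            nn.Dropout(drop_p)))
        self.mlp = ResidualAdd(nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size),
            nn.Dropout(drop_p)))

    def forward(self, x):
        x = self.attn(x)   # pre-norm attention + residual
        x = self.mlp(x)    # pre-norm MLP + residual
        return x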
Stacking Blocks
class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])
Again inheriting from nn.Sequential, this stacks depth encoder blocks.
The * in *[TransformerEncoderBlock(**kwargs) for _ in range(depth)] is needed because the blocks must be passed as separate arguments rather than as a single list: passing [1, 2, 3] hands the function one list argument, while passing *[1, 2, 3] unpacks it into the three arguments 1, 2, 3, as in the small example below.
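A tiny illustration of the * unpacking (my own example):

# Without * the list arrives as a single argument; with * its elements arrive separately
def show(*args):
    print(args)

show([1, 2, 3])    # ([1, 2, 3],)
show(*[1, 2, 3])   # (1, 2, 3)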
Head
from einops.layers.torch import Reduce

class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes))
The final head is the classification layer: Reduce mean-pools over the token dimension, collapsing (batch, 197, 768) down to a (batch, 768) vector, and LayerNorm followed by nn.Linear then maps it to the n_classes logits.
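The Reduce layer here is just a mean over the token axis; an equivalent plain-PyTorch check (my own sketch):

# Reduce('b n e -> b e', reduction='mean') is a mean over dim=1 (the token axis)
tokens = torch.randn(8, 197, 768)
pooled_einops = Reduce('b n e -> b e', reduction='mean')(tokens)
pooled_torch = tokens.mean(dim=1)
print(torch.allclose(pooled_einops, pooled_torch))   # True
print(pooled_torch.shape)                            # torch.Size([8, 768])

Note that this head mean-pools over all tokens; the original ViT paper instead classifies from the CLS token's final representation.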
Summary
class ViT(nn.Sequential):
    def __init__(self,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 emb_size: int = 768,
                 img_size: int = 224,
                 depth: int = 12,
                 n_classes: int = 1000,
                 **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size=emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes)
        )
summary(ViT(), (3, 224, 224), device='cpu')
The result of creating a ViT and printing its summary (summary here presumably comes from the torchsummary package):
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 768, 14, 14] 590,592
Rearrange-2 [-1, 196, 768] 0
PatchEmbedding-3 [-1, 197, 768] 0
LayerNorm-4 [-1, 197, 768] 1,536
Linear-5 [-1, 197, 2304] 1,771,776
Dropout-6 [-1, 8, 197, 197] 0
Linear-7 [-1, 197, 768] 590,592
MultiHeadAttention-8 [-1, 197, 768] 0
Dropout-9 [-1, 197, 768] 0
ResidualAdd-10 [-1, 197, 768] 0
LayerNorm-11 [-1, 197, 768] 1,536
Linear-12 [-1, 197, 3072] 2,362,368
GELU-13 [-1, 197, 3072] 0
Dropout-14 [-1, 197, 3072] 0
Linear-15 [-1, 197, 768] 2,360,064
Dropout-16 [-1, 197, 768] 0
ResidualAdd-17 [-1, 197, 768] 0
LayerNorm-18 [-1, 197, 768] 1,536
Linear-19 [-1, 197, 2304] 1,771,776
Dropout-20 [-1, 8, 197, 197] 0
Linear-21 [-1, 197, 768] 590,592
MultiHeadAttention-22 [-1, 197, 768] 0
Dropout-23 [-1, 197, 768] 0
ResidualAdd-24 [-1, 197, 768] 0
LayerNorm-25 [-1, 197, 768] 1,536
Linear-26 [-1, 197, 3072] 2,362,368
GELU-27 [-1, 197, 3072] 0
Dropout-28 [-1, 197, 3072] 0
Linear-29 [-1, 197, 768] 2,360,064
Dropout-30 [-1, 197, 768] 0
ResidualAdd-31 [-1, 197, 768] 0
LayerNorm-32 [-1, 197, 768] 1,536
Linear-33 [-1, 197, 2304] 1,771,776
Dropout-34 [-1, 8, 197, 197] 0
Linear-35 [-1, 197, 768] 590,592
MultiHeadAttention-36 [-1, 197, 768] 0
Dropout-37 [-1, 197, 768] 0
ResidualAdd-38 [-1, 197, 768] 0
LayerNorm-39 [-1, 197, 768] 1,536
Linear-40 [-1, 197, 3072] 2,362,368
GELU-41 [-1, 197, 3072] 0
Dropout-42 [-1, 197, 3072] 0
Linear-43 [-1, 197, 768] 2,360,064
Dropout-44 [-1, 197, 768] 0
ResidualAdd-45 [-1, 197, 768] 0
LayerNorm-46 [-1, 197, 768] 1,536
Linear-47 [-1, 197, 2304] 1,771,776
Dropout-48 [-1, 8, 197, 197] 0
Linear-49 [-1, 197, 768] 590,592
MultiHeadAttention-50 [-1, 197, 768] 0
Dropout-51 [-1, 197, 768] 0
ResidualAdd-52 [-1, 197, 768] 0
LayerNorm-53 [-1, 197, 768] 1,536
Linear-54 [-1, 197, 3072] 2,362,368
GELU-55 [-1, 197, 3072] 0
Dropout-56 [-1, 197, 3072] 0
Linear-57 [-1, 197, 768] 2,360,064
Dropout-58 [-1, 197, 768] 0
ResidualAdd-59 [-1, 197, 768] 0
LayerNorm-60 [-1, 197, 768] 1,536
Linear-61 [-1, 197, 2304] 1,771,776
Dropout-62 [-1, 8, 197, 197] 0
Linear-63 [-1, 197, 768] 590,592
MultiHeadAttention-64 [-1, 197, 768] 0
Dropout-65 [-1, 197, 768] 0
ResidualAdd-66 [-1, 197, 768] 0
LayerNorm-67 [-1, 197, 768] 1,536
Linear-68 [-1, 197, 3072] 2,362,368
GELU-69 [-1, 197, 3072] 0
Dropout-70 [-1, 197, 3072] 0
Linear-71 [-1, 197, 768] 2,360,064
Dropout-72 [-1, 197, 768] 0
ResidualAdd-73 [-1, 197, 768] 0
LayerNorm-74 [-1, 197, 768] 1,536
Linear-75 [-1, 197, 2304] 1,771,776
Dropout-76 [-1, 8, 197, 197] 0
Linear-77 [-1, 197, 768] 590,592
MultiHeadAttention-78 [-1, 197, 768] 0
Dropout-79 [-1, 197, 768] 0
ResidualAdd-80 [-1, 197, 768] 0
LayerNorm-81 [-1, 197, 768] 1,536
Linear-82 [-1, 197, 3072] 2,362,368
GELU-83 [-1, 197, 3072] 0
Dropout-84 [-1, 197, 3072] 0
Linear-85 [-1, 197, 768] 2,360,064
Dropout-86 [-1, 197, 768] 0
ResidualAdd-87 [-1, 197, 768] 0
LayerNorm-88 [-1, 197, 768] 1,536
Linear-89 [-1, 197, 2304] 1,771,776
Dropout-90 [-1, 8, 197, 197] 0
Linear-91 [-1, 197, 768] 590,592
MultiHeadAttention-92 [-1, 197, 768] 0
Dropout-93 [-1, 197, 768] 0
ResidualAdd-94 [-1, 197, 768] 0
LayerNorm-95 [-1, 197, 768] 1,536
Linear-96 [-1, 197, 3072] 2,362,368
GELU-97 [-1, 197, 3072] 0
Dropout-98 [-1, 197, 3072] 0
Linear-99 [-1, 197, 768] 2,360,064
Dropout-100 [-1, 197, 768] 0
ResidualAdd-101 [-1, 197, 768] 0
LayerNorm-102 [-1, 197, 768] 1,536
Linear-103 [-1, 197, 2304] 1,771,776
Dropout-104 [-1, 8, 197, 197] 0
Linear-105 [-1, 197, 768] 590,592
MultiHeadAttention-106 [-1, 197, 768] 0
Dropout-107 [-1, 197, 768] 0
ResidualAdd-108 [-1, 197, 768] 0
LayerNorm-109 [-1, 197, 768] 1,536
Linear-110 [-1, 197, 3072] 2,362,368
GELU-111 [-1, 197, 3072] 0
Dropout-112 [-1, 197, 3072] 0
Linear-113 [-1, 197, 768] 2,360,064
Dropout-114 [-1, 197, 768] 0
ResidualAdd-115 [-1, 197, 768] 0
LayerNorm-116 [-1, 197, 768] 1,536
Linear-117 [-1, 197, 2304] 1,771,776
Dropout-118 [-1, 8, 197, 197] 0
Linear-119 [-1, 197, 768] 590,592
MultiHeadAttention-120 [-1, 197, 768] 0
Dropout-121 [-1, 197, 768] 0
ResidualAdd-122 [-1, 197, 768] 0
LayerNorm-123 [-1, 197, 768] 1,536
Linear-124 [-1, 197, 3072] 2,362,368
GELU-125 [-1, 197, 3072] 0
Dropout-126 [-1, 197, 3072] 0
Linear-127 [-1, 197, 768] 2,360,064
Dropout-128 [-1, 197, 768] 0
ResidualAdd-129 [-1, 197, 768] 0
LayerNorm-130 [-1, 197, 768] 1,536
Linear-131 [-1, 197, 2304] 1,771,776
Dropout-132 [-1, 8, 197, 197] 0
Linear-133 [-1, 197, 768] 590,592
MultiHeadAttention-134 [-1, 197, 768] 0
Dropout-135 [-1, 197, 768] 0
ResidualAdd-136 [-1, 197, 768] 0
LayerNorm-137 [-1, 197, 768] 1,536
Linear-138 [-1, 197, 3072] 2,362,368
GELU-139 [-1, 197, 3072] 0
Dropout-140 [-1, 197, 3072] 0
Linear-141 [-1, 197, 768] 2,360,064
Dropout-142 [-1, 197, 768] 0
ResidualAdd-143 [-1, 197, 768] 0
LayerNorm-144 [-1, 197, 768] 1,536
Linear-145 [-1, 197, 2304] 1,771,776
Dropout-146 [-1, 8, 197, 197] 0
Linear-147 [-1, 197, 768] 590,592
MultiHeadAttention-148 [-1, 197, 768] 0
Dropout-149 [-1, 197, 768] 0
ResidualAdd-150 [-1, 197, 768] 0
LayerNorm-151 [-1, 197, 768] 1,536
Linear-152 [-1, 197, 3072] 2,362,368
GELU-153 [-1, 197, 3072] 0
Dropout-154 [-1, 197, 3072] 0
Linear-155 [-1, 197, 768] 2,360,064
Dropout-156 [-1, 197, 768] 0
ResidualAdd-157 [-1, 197, 768] 0
LayerNorm-158 [-1, 197, 768] 1,536
Linear-159 [-1, 197, 2304] 1,771,776
Dropout-160 [-1, 8, 197, 197] 0
Linear-161 [-1, 197, 768] 590,592
MultiHeadAttention-162 [-1, 197, 768] 0
Dropout-163 [-1, 197, 768] 0
ResidualAdd-164 [-1, 197, 768] 0
LayerNorm-165 [-1, 197, 768] 1,536
Linear-166 [-1, 197, 3072] 2,362,368
GELU-167 [-1, 197, 3072] 0
Dropout-168 [-1, 197, 3072] 0
Linear-169 [-1, 197, 768] 2,360,064
Dropout-170 [-1, 197, 768] 0
ResidualAdd-171 [-1, 197, 768] 0
Reduce-172 [-1, 768] 0
LayerNorm-173 [-1, 768] 1,536
Linear-174 [-1, 1000] 769,000
================================================================
Total params: 86,415,592
Trainable params: 86,415,592
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 364.33
Params size (MB): 329.65
Estimated Total Size (MB): 694.56
----------------------------------------------------------------
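As a sanity check (my own addition, not in the original post), a forward pass and a direct parameter count:

# Forward-pass and parameter-count sanity check (illustrative)
model = ViT()
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)   # torch.Size([2, 1000])

n_params = sum(p.numel() for p in model.parameters())
print(f'{n_params:,}')
# Roughly 86.6M. This can come out slightly higher than the torchsummary total above,
# since torchsummary may skip parameters registered directly as nn.Parameter
# (the CLS token and positional embeddings inside PatchEmbedding).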
Conclusion
This wraps up the ViT code review. Unlike the original encoder-decoder Transformer, ViT uses only the encoder, so the code stays short and is straightforward to implement.