AI

TransUNet - Transformer를 적용한 Segmentation Model 논문 리뷰

kimy 2021. 2. 25. 11:28

개요

2월 8일에 나온 Medical Image Segmentation을 목적으로 만들어진 TransUNet입니다. TransUNet은 기존의 발표된 ViT(Vision Transformer)를 이용해 인코딩 후 디코더를 이용해 Upsampling하여 Segmentation을 진행합니다. U-Net이 들어간 이름에서 알 수 있듯이 Upsampling시 기존 U-Net처럼 Skip Connection을 이용해 Segmentation Detail을 더 높여주는 방식을 취했습니다. TransUNet은 multi-organ CT Segmentation에서 State-of-the-arts 성능을 달성했습니다.

 

Architecture

TransUNet의 Architecture입니다. 여기서 중요한 점은 다음과 같습니다.

  • Skip Connection을 적용하기 위하여 기존의 ViT대신 CNN과 결합한 R50-ViT구조를 사용하게 됩니다. 기존의 ViT는 CNN을 거치지 않고 바로 패치들로 나누어 인코딩되게 되어 Upsampling시 Detail의 손실이 있습니다.
  • Skip Connection은 총 3개로 1/2 해상도, 1/4 해상도, 1/8 해상도에서 Concatenate 시켜줍니다.
  • ViT 인코딩시 (n_patch, D) 형태로 나오기 때문에 위치정보가 사라지므로 ViT에서 위치 인코딩을 더해줍니다. 이후 나온 인코딩 결과인 (n_patch, D)를 (n_patch, H, W) 형태로 변환해줍니다.
  • Upsampling할 때 concatenate 후 Conv3x3과 RELU를 적용합니다.

 

Image Squentialization

이미지를 입력받으면 HxWxC의 Shape입니다. ViT는 이미지를 패치사이즈 만큼 나눠서 Sequence 데이터로 만들어 입력하게 됩니다.

패치사이즈와 개수

따라서 이미지를 1차원 패치로 만들기 위해 위와 같은 형태로 변환하게 됩니다. Xi는 i=1, ..., N까지 있으며 개수는 HW/P^2으로 나오게 됩니다.

 

Patch Embedding

다음은 나눠진 패치들을 D차원의 latent 벡터로 임베딩시켜준 후 위치 정보의 손실을 막기위한 Positional Encoding을 더해주는 작업입니다.

패치에 E를 곱해주어 D차원으로 임베딩 시킨 후 NxD형태의 위치정보를 더해주면 Embedding의 완료입니다.

 

Encoding

임베딩된 z를 Transformer처럼 Mutihead Self Attention을 시켜주게 됩니다. MSA 전에는 Layer Normalization을 진행해줍니다. 뒤에 더해주는 z(l-1)은 Resnet처럼 Residual을 더해주는 부분입니다.

Attention 진행 후 LN을 거친 후 Multi-Layer Perceptron 후 Residual을 더하게 되면 한 인코더 Block이 완성됩니다. 논문에서는 원래의 ViT처럼 인코더 블럭을 총 12개 쌓아서 진행했습니다.

 

CNN-Transformer Hybrid as Encoder

Encoder

TransUNet에서는 Pure한 ViT와 달리 Resnet-50과 결합한 R50-ViT 구조를 사용합니다. ResNet의 마지막 feature map을 input으로 사용하며 각 패치사이즈는 1x1을 사용합니다. 이 디자인을 사용하면 디코더에서 중간의 고해상도 CNN Feature map을 끌어낼 수 있습니다. 그리고 논문에서 Pure ViT보다 더 좋은 성능을 냈다고 합니다.

 

Cascaded Upsampler (CUP)

CUP은 segmentation mask를 위한 여러개의 upsampling step으로 구성되어 있습니다. 인코더의 hidden feature(NxD)를 다시 (H/P x W/P x D)로 reshape 후 CUP을 이용해 (HxW)로 upsampling 하게 됩니다. 각 블럭은 2배 upsampling operator와 3x3 convolution layer, ReLU layer로 구성되어 있습니다.

TransUNet은 Hybird-Transformer와 CUP을 결합한 구조에 UNet처럼 Skip Connection을 적용한 구조입니다.

 

Experiments

multi-organ CT dataset을 이용해 실험한 결과입니다. 기존의 SoTA를 달성한 V-Net, DARR, U-Net, AttnUNet보다 좋은 성능을 보이는 것을 볼 수 있습니다.

  • ViT-None은 Non-Hybrid ViT 인코딩 후 원본 사이즈로 바로 업스케일링 한 모델입니다.
  • ViT-CUP은 Non-Hybrid ViT 인코딩 후 CUP을 사용해 단계적으로 업스케일링 한 모델입니다.
  • R50-ViT-CUP은 Hybrid ViT 인코더와 CUP을 사용한 모델입니다.
  • TransUNet은 R50-ViT-CUP에 Unet skip connection을 적용한 모델입니다.

결론적으로는 Hybrid ViT 인코더와 Cascaded Upsampler, Skip Connection을 적용했을 경우 SoTA의 성능을 나타내는 것을 확인할  수 있습니다.

 

Skip Connection의 개수에 따른 결과입니다. 더 많을수록 더 좋은 결과를 보입니다.

 

패치사이즈를 줄일수록 더 좋은결과를 보여주지만 논문에서는 16x16을 default size로 사용하였습니다.

 

ViT의 Large 모델은 더 나은 performance를 보여주지만 computation cost를 고려했을때 Base모델을 선택했다고 합니다.

 

각종 모델들과 GroundTruth의 비교결과입니다 TransUNet이 가장 Ground Truth에 일치하는 Prediction 결과를 보여줍니다.

 

결론

Transformer는 강력한 self-attention을 가진 모델로 알려져 있습니다. 이 논문에서는 트랜스포머를 medical image segmentation에 적용하는 첫번째 연구를 수행했습니다. TransUNet은 이미지 features를 시퀀스데이터로 다루어 global context를 잘 인코딩할 뿐 아니라 low-level의 CNN features를 Unet의 구조를 사용하여 잘 얻을 수 있습니다. 제안된 TransUNet은 다양한 데이터셋에서 더 좋은 결과를 보여줍니다.

 

구현한 코드는 아래 링크에서 확인할 수 있습니다.

github.com/Beckschen/TransUNet

 

Beckschen/TransUNet

This repository includes the official project of TransUNet, presented in our paper: TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation. - Beckschen/TransUNet

github.com