이전에 Vision Transformer의 Embedding까지 구현했다.
[ViT] Vision Transformer 구현 -1 Class Token, Position Embedding
이전에 이미지를 패치로 나누고 프로젝션 연산까지 진행하는 코드를 작성했다. [ViT] Vision Transformer 구현 -0 Linear Projection of Flattened Pathes Vision Transformer를 모듈별로 구현하며 궁금한 점을 기록했다
yeeca.tistory.com
class VisionTransformer_ ( nn . Module ):
def __init__ ( self , img_size , patch_size , embedd_dim ,):
super(). __init__ ()
self .img_size = img_size
self .p = patch_size
self .embedd_dim = embedd_dim
self .projection = nn.Conv2d( 3 ,embedd_dim,patch_size,patch_size)
self .class_token = nn.Parameter(torch.zeros( 1 , 1 , embedd_dim))
seq_length = (img_size // patch_size) ** 2 + 1
self .pos_embedding = nn.Parameter(torch.zeros( 1 , seq_length, embedd_dim).normal_(std= 0.02 ))
def _patch_embedd ( self , x ):
n,c,h,w = x.shape
n_h = h // self .p
n_w = w // self .p
# projection
x = self .projection(x)
x = x.reshape(n, self .embedd_dim,n_h*n_w)
x = x.permute( 0 , 2 , 1 )
# concat class token
ct = self .class_token.expand(n, -1 , self .embedd_dim)
x = torch.cat([ct,x],dim= 1 ) # (2,197,32)
# add pos
x = x+ self .pos_embedding
return x
def forward ( self , x ):
# embedding
x = self ._patch_embedd(x)
# encoder
# head
return x
Encoder는 NLP 분야에서 발표된 "Attention Is All You Need"(https://arxiv.org/abs/1706.03762 ) 논문의 Transformer Encoder랑 같다. Decoder는 사용하지 않으며 Head 부분에서 Classification이나 Objectdetection 등 Task에 따라 아키텍처를 수정한다.
Encoder 내부에서 사용되는 Normalization 방법은 Layer Normalization(LN)이다.
Batch Normalization(BN)을 사용하지 않는 이유는 BN은 배치 단위로 정규화를 진행하므로 배치 크기에 따라 성능이 좌우된다. LN은 채널 단위로 계산하므로 배치 크기에 덜 민감하다. 따라서 작은 크기의 배치에서 BN보다 상대적으로 안정적인 성능을 보일 수 있다. 또한 ViT는 패치로 나누어서 처리하는 구조이므로 채널 단위 정규화가 더 적합하다.
MSA는 Muti-head Self-Attention이다. 임베딩된 패치들은 LN 이후 Q, K, V로 나뉘어 Multi-head Attention을 수행한다. 그리고 Norm 이전의 Matrix(E : Embedd Matrix) 들과 Residual Connection으로 Add 연산을 진행하고 다시 LN 이후 MLP (Multi-Layer Perceptron) 블록에 입력되고 또 Residual Connection으로 Add 연산 진행한다. 이 과정을 L번 반복한다.
마지막 L 번째 연산이 끝난 z의 0번째 채널(Class 토큰 생성 후 Concat 한 부분)이 y 값이 된다.
pytorch에서 nn.MutiheadAttention 클래스를 제공하므로 일단 사용했다.
일단 이전까지 구현한 Embedding 직후 shape는 다음과 같다.
image_size = 224
img = torch.rand(( 2 , 3 , image_size, image_size)) # 3채널 컬러 이미지
patch_size = p = 16
embedding_size = 32
vit = VisionTransformer_(image_size,patch_size,embedding_size)
test = vit(img)
test.shape
>>>torch.Size([2, 197, 32])
위 선형대수 식에서 2)번만 구현하면 다음과 같다.
def MSA ( input , embedd_dim , num_heads , dropout ):
ln = nn.LayerNorm(embedd_dim)
msa = nn.MultiheadAttention(embedd_dim,num_heads,dropout,batch_first= True )
x = ln(input)
x,_ = msa(x,x,x,need_weights = False ) # q,k,v
x += input
print (_)
return x
test = MSA(test,embedding_size, 4 , 0.2 )
test.shape
>>>None
torch.Size([2, 197, 32])
입력과 출력 shape가 같기 때문에 Residual Connection이 가능하다. nn.MultiheadAttention에서 batch_first = True로 한 이유는 default shape는 (S, N, E) 이기 때문이다. batch_first = True로 하여 (N, S, E) shape로 맞춰준다. 아래는 공식 문서 내용이다.
batch_first – If True , then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).
need_weights = False로 한 이유는 여기선 attention weight을 필요로 하지 않기 때문이다. default가 True이므로 False로 바꿔준다.
need_weights (bool ) – If specified, returns attn_output_weights in addition to attn_outputs . Default: True .
num_heads는 embedd_dim과 나눠서 나머지가 0이어야 한다.
좀 더 자세한 설명과 구현은 아래 블로거 분이 잘 설명 해주셨다.
[DNN] VIT(vision transformer) 리뷰 및 코드구현(CIFAR10) (ICLR2021)
Introduction 안녕하세요 pulluper입니다. 👏 이번 포스팅에서는 NLP에서 강력한 성능으로 기준이 된 Transformer (Self-Attention)을 vision task에 적용하여 sota(state-of-the-art)의 성능을 달성한 ICLR2021에 발표된 vi
csm-kr.tistory.com
공식 깃허브(https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py#L64 )에 3) 번 식의 MLP에 gelu는 한 번 사용된다.
3)번 식을 구현하면 다음과 같다.
def MLP ( input , embedd_dim , mlp_dim , dropout ):
ln = nn.LayerNorm(embedd_dim)
fc1 = nn.Linear(embedd_dim,mlp_dim,bias = False )
gelu = nn.GELU()
fc2 = nn.Linear(mlp_dim, embedd_dim,bias = False )
drop = nn.Dropout(dropout)
x = gelu(fc1(ln(input)))
x = drop(x)
x = gelu(fc2(x))
x = drop(x)
x += input
return x
test = MLP(test,embedding_size, 64 , 0.2 )
test.shape
>>>torch.Size([2, 197, 32])
3) 번 연산 이후도 이전과 같은 shape이므로 이 과정을 L 번 반복할 수 있다.
2), 3)을 통합하여 인코더 클래스를 구현하면 다음과 같다.
class MLPBlock ( nn . Module ):
def __init__ ( self , hidden_dim , mlp_dim , dropout = False ):
super(). __init__ ()
self .ln = nn.LayerNorm(hidden_dim)
self .fc1 = nn.Linear(hidden_dim,mlp_dim,bias = False )
self .gelu = nn.GELU()
self .fc2 = nn.Linear(mlp_dim, hidden_dim,bias = False )
self .drop1 = nn.Dropout(dropout)
self .drop2 = nn.Dropout(dropout)
def forward ( self , input ):
x = self .gelu( self .fc1( self .ln(input)))
x = self .drop1(x)
x = self .gelu( self .fc2(x))
x = self .drop2(x)
x += input
return x
class EncoderBlock ( nn . Module ):
def __init__ ( self , hidden_dim , num_heads , mlp_dim , dropout ):
super(). __init__ ()
self .mlp = MLPBlock(hidden_dim,mlp_dim,dropout)
self .msa = nn.MultiheadAttention(hidden_dim,num_heads,dropout,batch_first= True )
self .ln = nn.LayerNorm(hidden_dim)
def forward ( self , input ):
# MSA
x = self .ln(input)
x,_ = self .msa(x,x,x,need_weights = False ) # q,k,v
x += input
# MLP
x = self .mlp(x)
return x
class Encoder ( nn . Module ):
def __init__ ( self , hidden_dim , num_heads , mlp_dim , num_layers , dropout ):
super(). __init__ ()
layers: OrderedDict[ str , nn.Module] = OrderedDict()
for i in range (num_layers):
layers[ f "encoder_layer_ {i} " ] = EncoderBlock(
hidden_dim,
num_heads,
mlp_dim,
dropout,
)
self .layers = nn.Sequential(layers)
def forward ( self , x ):
return self .layers(x)
encoder = Encoder(embedding_size, 8 , 64 , 5 , 0.1 )
test = encoder(test)
test.shape
>>> torch.Size([2, 197, 32])
hidden_dim은 지금까지 사용한 embedding_size와 같다. mlp_dim은 fc layer의 hidden_dim이다.
완성된 Vision Transformer 클래스는 다음과 같다.
class VisionTransformer ( nn . Module ):
def __init__ ( self , img_size , patch_size , embedd_dim , num_heads , mlp_dim , num_layers , head_dim , dropout , num_classes ):
super(). __init__ ()
self .img_size = img_size
self .p = patch_size
self .embedd_dim = embedd_dim
self .projection = nn.Conv2d( 3 ,embedd_dim,patch_size,patch_size)
self .class_token = nn.Parameter(torch.zeros( 1 , 1 , embedd_dim))
seq_length = (img_size // patch_size) ** 2 + 1
self .pos_embedding = nn.Parameter(torch.zeros( 1 , seq_length, embedd_dim).normal_(std= 0.02 ))
self .encoder = Encoder(embedd_dim,num_heads,mlp_dim,num_layers,dropout)
self .head = nn.Sequential(
nn.Linear(embedd_dim,head_dim,bias= False ),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(head_dim,num_classes,bias= False )
)
def _patch_embedd ( self , x ):
n,c,h,w = x.shape
n_h = h // self .p
n_w = w // self .p
# projection
x = self .projection(x)
x = x.reshape(n, self .embedd_dim,n_h*n_w)
x = x.permute( 0 , 2 , 1 )
# concat class token
ct = self .class_token.expand(n, -1 , self .embedd_dim)
x = torch.cat([ct,x],dim= 1 ) # (2,197,32)
# add pos
x = x+ self .pos_embedding
return x
def forward ( self , x ):
# embedding
x = self ._patch_embedd(x)
# encoder
x = self .encoder(x)
x = x[:, 0 ]
# head
x = self .head(x)
return x
vit = VisionTransformer(
img_size = 224 ,
patch_size = 16 ,
embedd_dim = 32 ,
num_heads = 8 ,
mlp_dim = 128 ,
num_layers = 5 ,
head_dim = 64 ,
dropout = 0.1 ,
num_classes = 10
)
re = vit(img)
re.shape
>>>torch.Size([2, 10])
댓글