https://d2l.ai/chapter_attention-mechanisms-and-transformers/vision-transformer.html
Hello everyone,
Am I wrong in thinking that there is something wrong with this code?
class ViTBlock(nn.Module):
    def __init__(self, num_hiddens, norm_shape, mlp_num_hiddens,
                 num_heads, dropout, use_bias=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(norm_shape)
        self.attention = d2l.MultiHeadAttention(num_hiddens, num_heads,
                                                dropout, use_bias)
        self.ln2 = nn.LayerNorm(norm_shape)
        self.mlp = ViTMLP(mlp_num_hiddens, num_hiddens, dropout)

    def forward(self, X, valid_lens=None):
        X = self.ln1(X)
        return X + self.mlp(self.ln2(
            X + self.attention(X, X, X, valid_lens)))
More specifically, the forward method.
Based on the schema in Fig. 11.8.1, we take the input and normalize it while keeping the original, which we later sum with the output of the multi-head attention; the same applies to the block with the MLP inside.
But in the code from forward method:
def forward(self, X, valid_lens=None):
    X = self.ln1(X)
    return X + self.mlp(self.ln2(
        X + self.attention(X, X, X, valid_lens)))
it looks like in the first line the original input is overwritten with the output of ln1, and that normalized tensor is then used in all the later operations, including the residual connections.
My gut feeling tells me that it should be something like this:
def forward(self, X, valid_lens=None):
    X_norm = self.ln1(X)
    X = X + self.attention(X_norm, X_norm, X_norm, valid_lens)
    X_norm = self.ln2(X)
    return X + self.mlp(X_norm)
Maybe I am missing something…
I think you are right. For me it wouldn't converge (loss > 2.0, accuracy ~ 0.1), although I had to modify the code for an older torch version.
Then, with the fixed forward function, it works well: accuracy ~ 0.7 after 10 epochs.
Excuse me, can a CNN be added to ViT? How would we modify the code?
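One common option (not from the book, just a sketch) is a hybrid ViT: replace the linear patch projection with a small convolutional stem, so that each spatial position of the stem's output feature map becomes one token. The class below, ConvStemPatchEmbedding, is hypothetical; its name, layer sizes, and channel counts are assumptions rather than d2l code.

from torch import nn

class ConvStemPatchEmbedding(nn.Module):
    # Hypothetical CNN stem replacing the linear patch projection.
    def __init__(self, num_hiddens, in_channels=3):
        super().__init__()
        self.stem = nn.Sequential(
            # Two stride-2 convolutions downsample the image.
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128), nn.ReLU(),
            # A 1x1 convolution maps to the Transformer's hidden size.
            nn.Conv2d(128, num_hiddens, kernel_size=1))

    def forward(self, X):
        # (batch, num_hiddens, h, w) -> (batch, h * w, num_hiddens)
        return self.stem(X).flatten(2).transpose(1, 2)

In the ViT class you would then construct this stem instead of PatchEmbedding and recompute the number of tokens (and hence the positional embedding size) from the stem's output resolution.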
For the record, this issue is now fixed.
The forward pass of vision Transformers below is straightforward. First, input images are fed into a PatchEmbedding instance, whose output is concatenated with the "<cls>" token embedding. They are summed with learnable positional embeddings before dropout. Then the output is fed into the Transformer encoder that stacks num_blks instances of the ViTBlock class. Finally, the representation of the "<cls>" token is projected by the network head.
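For reference, here is a minimal sketch of that forward pass as a method of the ViT class, assuming torch and nn are already imported as in the chapter. The attribute names (patch_embedding, cls_token, pos_embedding, dropout, blks, head) are assumed to mirror the description above; this is a sketch, not a verbatim copy of the book's implementation.

def forward(self, X):
    X = self.patch_embedding(X)               # (batch, num_patches, num_hiddens)
    cls = self.cls_token.expand(X.shape[0], -1, -1)
    X = torch.cat((cls, X), dim=1)            # prepend the <cls> token embedding
    X = self.dropout(X + self.pos_embedding)  # add learnable positional embeddings
    for blk in self.blks:                     # num_blks stacked ViTBlock instances
        X = blk(X)
    return self.head(X[:, 0])                 # project the <cls> representation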
How could we be certain that what has been learned in these positional embeddings is truly about position?
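One way to probe this (a sketch, not from the book) is to look at the pairwise cosine similarity of the learned positional embeddings after training: if they really encode position, embeddings of nearby patches tend to be more similar than those of distant ones, as reported in the original ViT paper. The snippet assumes the trained model exposes pos_embedding with shape (1, num_steps, num_hiddens) and that the first step belongs to the <cls> token.

import torch.nn.functional as F

def pos_embedding_similarity(pos_embedding):
    # pos_embedding: (1, num_steps, num_hiddens); drop the <cls> position.
    emb = pos_embedding.squeeze(0)[1:]
    emb = F.normalize(emb, dim=-1)   # unit-norm each positional embedding
    return emb @ emb.T               # (num_patches, num_patches) cosine similarities

Reshaping one row of this matrix back to the patch grid and plotting it typically shows higher similarity around the corresponding patch when the embeddings capture 2D position.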
It seems there are some mistakes in Fig. 11.8.1 compared with the code block:
a. the <cls> token should be placed before the patch embedding, but it is instead placed after the patch embedding.
b. the "+" operation should be fed by the positional embedding and the result of concatenating the <cls> token with the patch embeddings.
I tried using Adam instead of SGD, and the loss stays around 2.5 with an accuracy of about 10%. SGD with momentum performs the same. Does the moving average of the gradient keep the model in a local minimum? But momentum is a method for escaping saddle points, so why does it do worse here?
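One likely culprit (an assumption, since the training setup is not shown) is that the learning rate tuned for plain SGD is too large once Adam or momentum is used; Adam typically works with learning rates around 1e-3 or smaller. A minimal sketch of what one might try, with a placeholder model and assumed hyperparameters:

import torch
from torch import nn

# Placeholder standing in for the ViT; only the optimizer settings matter here.
model = nn.Linear(10, 10)

# Adam usually needs a much smaller learning rate than one tuned for plain SGD.
adam = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

# When adding momentum to SGD, reducing the learning rate is a common first step.
sgd_momentum = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)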