Vision Transformer

https://d2l.ai/chapter_attention-mechanisms-and-transformers/vision-transformer.html


Hello everyone,

Am I wrong thinking that there is something wrong with this code?

class ViTBlock(nn.Module):
    def __init__(self, num_hiddens, norm_shape, mlp_num_hiddens,
                 num_heads, dropout, use_bias=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(norm_shape)
        self.attention = d2l.MultiHeadAttention(num_hiddens, num_heads,
                                                dropout, use_bias)
        self.ln2 = nn.LayerNorm(norm_shape)
        self.mlp = ViTMLP(mlp_num_hiddens, num_hiddens, dropout)

    def forward(self, X, valid_lens=None):
        X = self.ln1(X)
        return X + self.mlp(self.ln2(
            X + self.attention(X, X, X, valid_lens)))

More specifically, the forward method.

Based on the schematic in Fig. 11.8.1, we take the input and normalize it, but we keep the original, which we later sum with the output of the multi-head attention; the same happens for the block with the MLP inside.
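In other words, my reading of Fig. 11.8.1 is the pre-normalization arrangement

$$X' = X + \mathrm{MHA}(\mathrm{LN}_1(X), \mathrm{LN}_1(X), \mathrm{LN}_1(X)), \qquad Y = X' + \mathrm{MLP}(\mathrm{LN}_2(X')),$$

where the residual branches carry the unnormalized X and X'.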

But in the code from forward method:

def forward(self, X, valid_lens=None):
    X = self.ln1(X)
    return X + self.mlp(self.ln2(
        X + self.attention(X, X, X, valid_lens)))

it looks like, in the first line, the original input is overwritten with the output of ln1, and that overwritten value is then used in the later operations.

My gut feeling tells me that it should be something like this:

def forward(self, X, valid_lens=None):
    X_norm = self.ln1(X)
    X = X + self.attention(X_norm, X_norm, X_norm, valid_lens)
    X_norm = self.ln2(X)
    return X + self.mlp(X_norm)

Maybe I am missing something…
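Either way, here is a quick shape check with the toy sizes used in the section (assuming ViTMLP and d2l.MultiHeadAttention are already defined); it passes with both versions of forward, since both preserve the input shape, so the difference only shows up during training:

import torch

X = torch.ones((2, 100, 24))        # (batch size, number of tokens, num_hiddens)
blk = ViTBlock(24, 24, 48, 8, 0.5)  # num_hiddens, norm_shape, mlp_num_hiddens, num_heads, dropout
blk.eval()
print(blk(X).shape)                 # torch.Size([2, 100, 24])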

I think you are right. For me it wouldn't converge (loss > 2.0, accuracy ~0.1), although I had to modify the code for an older torch version. With the fixed forward function it works well: accuracy ~0.7 after 10 epochs.

Excuse me, can a CNN be added to ViT? How would we modify the code?
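Not an official answer, but one common way (the "hybrid" variant in the original ViT paper) is to replace the single-conv PatchEmbedding with a small CNN stem whose output feature map is flattened into tokens; everything after the patch embedding stays the same. A rough sketch with made-up layer widths and class name (it assumes patch_size is divisible by 4 and keeps the total stride equal to patch_size, so the number of tokens does not change):

from torch import nn

class ConvStemPatchEmbedding(nn.Module):
    """Hypothetical CNN stem that produces patch tokens for a ViT."""
    def __init__(self, img_size=96, patch_size=16, num_hiddens=512):
        super().__init__()
        # Same number of tokens as the original PatchEmbedding, so the
        # positional embedding in the ViT class keeps its length.
        self.num_patches = (img_size // patch_size) ** 2
        self.stem = nn.Sequential(
            nn.LazyConv2d(num_hiddens // 2, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.LazyConv2d(num_hiddens, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # Final projection; total stride = 2 * 2 * (patch_size // 4) = patch_size
            nn.LazyConv2d(num_hiddens, kernel_size=patch_size // 4,
                          stride=patch_size // 4))

    def forward(self, X):
        # (batch, num_hiddens, h, w) -> (batch, h * w, num_hiddens)
        return self.stem(X).flatten(2).transpose(1, 2)

In the ViT class you would then construct this module instead of PatchEmbedding; the "<cls>" token, positional embeddings, and encoder blocks are unchanged.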

For the record, this issue is now fixed.

My solutions to the exercises of Section 11.8


The forward pass of vision Transformers below is straightforward. First, input images are fed into a PatchEmbedding instance, whose output is concatenated with the “<cls>” token embedding. The result is summed with learnable positional embeddings before dropout. Then the output is fed into the Transformer encoder that stacks num_blks instances of the ViTBlock class. Finally, the representation of the “<cls>” token is projected by the network head.
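As a rough paraphrase of that forward pass in code (patch_embedding, cls_token, pos_embedding, dropout, blks, and head are the parameters and submodules the ViT class sets up in __init__):

def forward(self, X):
    X = self.patch_embedding(X)                      # (batch, num_patches, num_hiddens)
    cls = self.cls_token.expand(X.shape[0], -1, -1)  # one "<cls>" embedding per example
    X = torch.cat((cls, X), dim=1)                   # prepend the "<cls>" token
    X = self.dropout(X + self.pos_embedding)         # add learnable positional embeddings
    for blk in self.blks:                            # num_blks stacked ViTBlock instances
        X = blk(X)
    return self.head(X[:, 0])                        # classify from the "<cls>" representation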

How could we be certain that what has been learned in these positional embeddings is truly about position?
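One sketch of a check (not from the book; it assumes a trained model whose pos_embedding parameter has shape (1, num_patches + 1, num_hiddens)): compute the cosine similarity between the learned embeddings of all patch positions. If they really encode position, embeddings of spatially nearby patches should be noticeably more similar than those of distant patches, which is the pattern reported in the original ViT paper.

import torch.nn.functional as F

pos = model.pos_embedding.detach().squeeze(0)[1:]  # drop the "<cls>" slot -> (num_patches, num_hiddens)
sim = F.cosine_similarity(pos.unsqueeze(1), pos.unsqueeze(0), dim=-1)
grid = int(pos.shape[0] ** 0.5)
# Each row of sim, reshaped to the patch grid, shows how similar one position's
# embedding is to every other position's; look for a bump around its own location.
print(sim[0].reshape(grid, grid))

Another simple probe is to permute the patch order at test time and see whether accuracy drops.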