Vision Transformer

https://d2l.ai/chapter_attention-mechanisms-and-transformers/vision-transformer.html

Hello everyone,

Am I wrong in thinking that there is something wrong with this code?

class ViTBlock(nn.Module):
    def __init__(self, num_hiddens, norm_shape, mlp_num_hiddens,
                 num_heads, dropout, use_bias=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(norm_shape)
        self.attention = d2l.MultiHeadAttention(num_hiddens, num_heads,
                                                dropout, use_bias)
        self.ln2 = nn.LayerNorm(norm_shape)
        self.mlp = ViTMLP(mlp_num_hiddens, num_hiddens, dropout)

    def forward(self, X, valid_lens=None):
        X = self.ln1(X)
        return X + self.mlp(self.ln2(
            X + self.attention(X, X, X, valid_lens)))

More specifically, with the forward method.

Based on the diagram in Fig. 11.8.1, we take the input and normalize it, but keep the original, which is later added to the output of the multi-head attention. The same applies to the block containing the MLP: its input is normalized, and the unnormalized input is added to the MLP output.

But in the code of the forward method:

def forward(self, X, valid_lens=None):
    X = self.ln1(X)
    return X + self.mlp(self.ln2(
        X + self.attention(X, X, X, valid_lens)))

it looks like the first line overwrites the original input with the output of ln1, and that normalized value is then used in all later operations, including both residual additions.

My gut feeling tells me that it should be something like this:

def forward(self, X, valid_lens=None):
    X_norm = self.ln1(X)
    X = X + self.attention(X_norm, X_norm, X_norm, valid_lens)
    X_norm = self.ln2(X)
    return X + self.mlp(X_norm)
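
A quick way to see the difference without training anything is to replace attention and the MLP with stand-ins that always output zeros. If the skip path carries the raw input, the block should then return its input unchanged. The sketch below uses plain Python lists instead of tensors, a layer norm without learned scale and shift, and hypothetical helper names (layer_norm, zero_sublayer, block_as_posted, block_pre_norm), so it illustrates the data flow only and is not the d2l implementation:

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize a list of floats to zero mean, unit variance (no learned scale/shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def zero_sublayer(x):
    """Stand-in for attention or the MLP (an assumption for illustration): outputs zeros."""
    return [0.0 for _ in x]

def add(a, b):
    """Element-wise sum, i.e. the residual addition."""
    return [u + v for u, v in zip(a, b)]

def block_as_posted(x, attention=zero_sublayer, mlp=zero_sublayer):
    # Mirrors the forward method in question: x is overwritten by ln1(x),
    # so the normalized value ends up on the residual path.
    x = layer_norm(x)
    return add(x, mlp(layer_norm(add(x, attention(x)))))

def block_pre_norm(x, attention=zero_sublayer, mlp=zero_sublayer):
    # Mirrors the proposed fix: only the sublayer inputs are normalized,
    # the raw x stays on the residual path.
    x = add(x, attention(layer_norm(x)))
    return add(x, mlp(layer_norm(x)))

x = [1.0, 2.0, 3.0, 10.0]
print(block_pre_norm(x))   # [1.0, 2.0, 3.0, 10.0], the input passes through untouched
print(block_as_posted(x))  # the layer-normalized input, not x itself
```

With zero sublayers the proposed version is an exact identity, while the posted version returns ln1(X). In the posted version the attention output also never reaches the final sum directly, because the outer addition uses ln1(X) rather than X + attention(...).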

Maybe I am missing something…

I think you are right. For me it wouldn’t converge (loss > 2.0, accuracy ~ 0.1), although I had to modify the code for an older torch version.
Then with the fixed forward function it works well, accuracy ~ 0.7 after 10 epochs.

Excuse me, can a CNN be added to ViT? How would we modify the code?