I have been trying to implement a Transformer architecture in PyTorch, following the Attention Is All You Need paper and comparing my code against The Annotated Transformer blog post. I noticed that in their implementation of Multi-Head Attention, they use three nn.Linear(d_model, d_model) layers to project the encoder input before splitting these projections into (n_heads, d_k) matrices for the attention.
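Roughly, the pattern I am referring to looks like this (a simplified sketch of the idea, not their exact code; the variable names are mine):

import torch
import torch.nn as nn

d_model, n_heads = 512, 8
d_k = d_model // n_heads

# one big (d_model -> d_model) projection per input, shared across all heads
q_proj = nn.Linear(d_model, d_model)

x = torch.randn(2, 10, d_model)  # (batch, seq_len, d_model)
# the projected output is then reshaped into per-head chunks:
# (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_k)
q = q_proj(x).view(2, -1, n_heads, d_k).transpose(1, 2)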
But as I understand the paper, we need n_heads separate nn.Linear(d_model, d_k) layers for each of the queries, keys and values, as we can see in the Multi-Head Attention diagram from the paper:
We clearly see as many nn.Linear layers as there are heads, and the authors' explanation points the same way.
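As far as I remember, the paper defines multi-head attention as

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_{1}, \ldots, \mathrm{head}_{h})W^{O}, \quad \mathrm{head}_{i} = \mathrm{Attention}(QW_{i}^{Q},\, KW_{i}^{K},\, VW_{i}^{V}),$$

with $W_{i}^{Q}, W_{i}^{K} \in \mathbb{R}^{d_{model} \times d_{k}}$, $W_{i}^{V} \in \mathbb{R}^{d_{model} \times d_{v}}$ and $W^{O} \in \mathbb{R}^{h d_{v} \times d_{model}}$.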
That is, each $head_{i}$ uses its own $W_{i}^{Q}$, $W_{i}^{K}$ and $W_{i}^{V}$. So in my implementation I did this:
import torch
import torch.nn as nn

class MultiHeadedAttention(nn.Module):
    def __init__(self, d_model=512, h=8):
        super(MultiHeadedAttention, self).__init__()
        self.d_model = d_model
        self.h = h
        self.d_k = d_model // h
        # one (d_model -> d_k) projection per head, for each of Q, K and V
        self.query_linears = nn.ModuleList([nn.Linear(d_model, self.d_k) for _ in range(h)])
        self.key_linears = nn.ModuleList([nn.Linear(d_model, self.d_k) for _ in range(h)])
        self.value_linears = nn.ModuleList([nn.Linear(d_model, self.d_k) for _ in range(h)])
        # final output projection W^O
        self.projection_layer = nn.Linear(h * self.d_k, d_model)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        # apply each head's projection, then stack along the head dimension:
        # (batch, seq_len, d_model) -> (batch, h, seq_len, d_k)
        queries = torch.cat([linear(Q).view(batch_size, 1, -1, self.d_k) for linear in self.query_linears], dim=1)
        keys = torch.cat([linear(K).view(batch_size, 1, -1, self.d_k) for linear in self.key_linears], dim=1)
        values = torch.cat([linear(V).view(batch_size, 1, -1, self.d_k) for linear in self.value_linears], dim=1)
        # scaled_dot_product_attention is defined elsewhere in my code
        x = scaled_dot_product_attention(queries, keys, values, mask)
        # (batch, h, seq_len, d_k) -> (batch, seq_len, h * d_k) -> (batch, seq_len, d_model)
        x = x.transpose(1, 2)
        x = x.contiguous()
        x = x.view(batch_size, -1, self.h * self.d_k)
        x = self.projection_layer(x)
        return x
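Here scaled_dot_product_attention is the usual $\mathrm{softmax}(QK^{T}/\sqrt{d_{k}})V$; a minimal version (not my exact helper, just so the snippet above is self-contained) plus a quick shape check:

import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(queries, keys, values, mask=None):
    # queries, keys, values: (batch, h, seq_len, d_k)
    d_k = queries.size(-1)
    scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, values)

mha = MultiHeadedAttention(d_model=512, h=8)
x = torch.randn(2, 10, 512)   # (batch, seq_len, d_model)
print(mha(x, x, x).shape)     # torch.Size([2, 10, 512])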
But I'm surely missing a key piece of understanding, and I'd be really grateful if someone could point it out to me.
Thank you.