images are from https://zh-v2.d2l.ai/
An architecture commonly used in NLP and other types of tasks
Encoder: takes the raw input and represents it as tensors after processing (could be word2vec, neural layers, attention, ...)
Decoder: mainly outputs the result in the desired form ([0, 1], a probability distribution, a classification, etc.)
```python
from torch import nn


class Encoder(nn.Module):
    """Base encoder interface: turns the raw input into an intermediate representation."""
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError


class Decoder(nn.Module):
    """Base decoder interface: turns the encoder's outputs (and its own inputs) into the final output."""
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, encoder_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError
```
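For context, a minimal sketch of how the two interfaces compose, mirroring d2l.ai's EncoderDecoder wrapper: the encoder digests the source, its outputs initialize the decoder state, and the decoder produces the output.

```python
class EncoderDecoder(nn.Module):
    """Glue an Encoder and a Decoder together."""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)
```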
A specific type of task whose input and output are both sequences of arbitrary length
Ex. Machine Translation
Common architecture of seq2seq models:
Machine Translation using RNN
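A minimal sketch of the encoder side for RNN-based machine translation, in the spirit of d2l.ai's Seq2SeqEncoder (an embedding layer feeding a GRU). It reuses the Encoder base class defined above; the layer sizes are placeholders.

```python
from torch import nn


class Seq2SeqEncoder(Encoder):
    """RNN encoder sketch: embed the source tokens, then run a multi-layer GRU."""
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqEncoder, self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)

    def forward(self, X, *args):
        # X: (batch_size, num_steps) of token indices
        X = self.embedding(X)      # (batch_size, num_steps, embed_size)
        X = X.permute(1, 0, 2)     # nn.GRU expects (num_steps, batch_size, embed_size)
        output, state = self.rnn(X)
        # output: (num_steps, batch_size, num_hiddens) -- later used as Keys/Values by attention
        # state:  (num_layers, batch_size, num_hiddens) -- consumed by Decoder.init_state
        return output, state
```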
BLEU (Bilingual Evaluation Understudy) for machine translation
n-gram accuracy (precision)
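Following d2l.ai's definition, where $p_n$ is the n-gram precision, $\mathrm{len}_{\text{label}}$ and $\mathrm{len}_{\text{pred}}$ are the reference and predicted lengths, and the exponential factor penalizes predictions that are too short:

$$\exp\left(\min\left(0,\ 1 - \frac{\mathrm{len}_{\text{label}}}{\mathrm{len}_{\text{pred}}}\right)\right)\prod_{n=1}^{k} p_n^{1/2^n}$$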
Attention Mechanism, KVQ
Key: what is presented
Value: sensory inputs(?)
Query: what we are interested in

Attention Score: models the relationship (importance, similarity) between Keys & Queries
Kernel Regression
Additive Attention
Scaled Dot-Product Attention
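For reference, in d2l.ai's notation (query $\mathbf{q}$, key $\mathbf{k}$ of dimension $d$, learnable $\mathbf{W}_q$, $\mathbf{W}_k$, $\mathbf{w}_v$); Nadaraya-Watson kernel regression can be read as attention pooling with a Gaussian-kernel score, while the two learned scoring functions are:

$$a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \tanh\left(\mathbf{W}_q \mathbf{q} + \mathbf{W}_k \mathbf{k}\right) \quad \text{(additive attention)}$$

$$a(\mathbf{q}, \mathbf{k}) = \frac{\mathbf{q}^\top \mathbf{k}}{\sqrt{d}} \quad \text{(scaled dot-product attention)}$$

The attention weights are the softmax of these scores over all keys, and the output is the correspondingly weighted sum of the values.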
In seq2seq with attention, the Query is the decoder's input (its hidden state from the previous step), while Key & Value are both the encoder's outputs (final-layer hidden states at each time step); each encoder output serves as a Key-Value pair.
A commonly used positional encoding method uses these sine and cosine functions (position $i$, dimensions $2j$ and $2j+1$, model dimension $d$):

$$p_{i,2j} = \sin\left(\frac{i}{10000^{2j/d}}\right), \qquad p_{i,2j+1} = \cos\left(\frac{i}{10000^{2j/d}}\right)$$
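A minimal module sketch of this encoding, mirroring d2l.ai's PositionalEncoding (the max_len buffer size and the dropout applied after adding P follow that design):

```python
import torch
from torch import nn


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding sketch (assumes even num_hiddens)."""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # P: (1, max_len, num_hiddens), added to the input embeddings
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
            10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)  # even dimensions
        self.P[:, :, 1::2] = torch.cos(X)  # odd dimensions

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
```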
Multi-head Attention aims to capture different "relationships" between Query and Key by using multiple parallel attention layers (heads) and concatenating their outputs to get the final result.
Mathematically:
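In d2l.ai's notation, with attention pooling $f$ (e.g. scaled dot-product attention) and learned projections $\mathbf{W}_i^{(q)}, \mathbf{W}_i^{(k)}, \mathbf{W}_i^{(v)}, \mathbf{W}_o$, each head $\mathbf{h}_i$ and the final output are:

$$\mathbf{h}_i = f\left(\mathbf{W}_i^{(q)}\mathbf{q},\ \mathbf{W}_i^{(k)}\mathbf{k},\ \mathbf{W}_i^{(v)}\mathbf{v}\right) \in \mathbb{R}^{p_v}, \quad i = 1, \ldots, h$$

$$\mathbf{W}_o \begin{bmatrix}\mathbf{h}_1 \\ \vdots \\ \mathbf{h}_h\end{bmatrix} \in \mathbb{R}^{p_o}$$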
```python
import torch
from torch import nn
from d2l import torch as d2l


class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # queries/keys/values: (batch_size, num_queries or num_kvpairs, size);
        # num_keys must equal num_values, num_queries may differ

        # transformed queries:
        # (batch_size * num_heads, num_queries, num_hiddens / num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # copy each valid length num_heads times so every head masks
            # the same padded positions
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # (batch_size * num_heads, num_queries, num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # (batch_size, num_queries, num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
```
```python
def transpose_qkv(X, num_heads):
    """Split the hidden dimension into num_heads heads and fold them into the batch dimension."""
    # X: (batch_size, num_queries, num_hiddens)

    # (batch_size, num_queries, num_heads, num_hiddens / num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # (batch_size, num_heads, num_queries, num_hiddens / num_heads)
    X = X.permute(0, 2, 1, 3)

    # (batch_size * num_heads, num_queries, num_hiddens / num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    """Reverse `transpose_qkv`."""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
```
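As a quick sanity check on the shapes, mirroring the example in d2l.ai (the sizes below are arbitrary):

```python
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()

batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))   # queries
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))   # keys = values
print(attention(X, Y, Y, valid_lens).shape)  # torch.Size([2, 4, 100])
```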
Annotated graph