images are from https://zh-v2.d2l.ai/
Encoder-Decoder: an architecture commonly used in NLP and other types of tasks
Encoder: takes the raw input and represents it as tensors after processing (could be word2vec, neural layers, attention, ...)
Decoder: mainly for converting that representation into the desired output form ([0, 1], a probability distribution, a classification, etc.)
```python
from torch import nn

class Encoder(nn.Module):
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError


class Decoder(nn.Module):
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, encoder_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError
```
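A minimal sketch of how the two halves are chained together (this follows the `EncoderDecoder` class from d2l; treat it as an illustration rather than the only way to combine them):

```python
class EncoderDecoder(nn.Module):
    """Chain an Encoder and a Decoder into a single model."""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        # turn the encoder outputs into the decoder's initial state
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)
```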
Seq2seq: a specific type of task whose input and output are both sequences of arbitrary length
Ex. Machine Translation
Common architecture of seq2seq models:
Machine Translation using RNN
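A minimal sketch of the encoder half of an RNN translation model (an embedding layer followed by a GRU), roughly along the lines of d2l's `Seq2SeqEncoder`; the layer names and hyperparameters here are illustrative:

```python
from torch import nn

class Seq2SeqEncoder(Encoder):
    """RNN encoder: embed the source tokens, then run them through a GRU."""
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqEncoder, self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)

    def forward(self, X, *args):
        # X: (batch_size, num_steps) of token indices
        X = self.embedding(X)    # (batch_size, num_steps, embed_size)
        X = X.permute(1, 0, 2)   # GRU expects (num_steps, batch_size, embed_size)
        output, state = self.rnn(X)
        # output: (num_steps, batch_size, num_hiddens)
        # state:  (num_layers, batch_size, num_hiddens), used to init the decoder
        return output, state
```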
BLEU (Bilingual Evaluation Understudy) for machine translation
based on n-gram precision, with a penalty for predictions that are too short
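One standard formulation (the one used in d2l), where $p_n$ is the $n$-gram precision and $\mathrm{len}_{\text{label}}$, $\mathrm{len}_{\text{pred}}$ are the lengths of the reference and the prediction:

$$\mathrm{BLEU} = \exp\left(\min\left(0,\; 1 - \frac{\mathrm{len}_{\text{label}}}{\mathrm{len}_{\text{pred}}}\right)\right) \prod_{n=1}^{k} p_n^{1/2^n}$$

Longer $n$-grams get exponents closer to 1, so matching them is rewarded more.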
Attention Mechanism, KVQ
Key: what is presented
Value: the sensory inputs(?)
Query: what we are interested in
Attention Score: models the relationship (importance, similarity) between Keys and Queries
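In general, attention pooling is a weighted sum of the values, with weights obtained by pushing a scoring function $a(\mathbf{q}, \mathbf{k}_i)$ through a softmax (standard notation, not tied to any one model):

$$f(\mathbf{q}) = \sum_{i=1}^{m} \alpha(\mathbf{q}, \mathbf{k}_i)\,\mathbf{v}_i, \qquad \alpha(\mathbf{q}, \mathbf{k}_i) = \frac{\exp\big(a(\mathbf{q}, \mathbf{k}_i)\big)}{\sum_{j=1}^{m} \exp\big(a(\mathbf{q}, \mathbf{k}_j)\big)}$$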
Kernel Regression
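Nadaraya-Watson kernel regression is the classic non-parametric example: each training pair $(x_i, y_i)$ acts as a key-value pair, the test input $x$ is the query, and with a Gaussian kernel the weights reduce to a softmax:

$$f(x) = \sum_{i=1}^{n} \frac{K(x - x_i)}{\sum_{j=1}^{n} K(x - x_j)}\, y_i = \sum_{i=1}^{n} \mathrm{softmax}\!\left(-\frac{(x - x_i)^2}{2}\right) y_i \quad \text{(Gaussian kernel)}$$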
Additive Attention
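Additive attention scores a query $\mathbf{q} \in \mathbb{R}^q$ against a key $\mathbf{k} \in \mathbb{R}^k$ with a small one-hidden-layer MLP (learnable $\mathbf{W}_q \in \mathbb{R}^{h \times q}$, $\mathbf{W}_k \in \mathbb{R}^{h \times k}$, $\mathbf{w}_v \in \mathbb{R}^{h}$), so queries and keys may have different dimensions:

$$a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \tanh(\mathbf{W}_q \mathbf{q} + \mathbf{W}_k \mathbf{k})$$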
Scaled Dot-Product Attention
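Scaled dot-product attention assumes queries and keys share the same dimension $d$ and divides by $\sqrt{d}$ so the variance of the scores does not grow with $d$; in minibatch matrix form:

$$a(\mathbf{q}, \mathbf{k}) = \frac{\mathbf{q}^\top \mathbf{k}}{\sqrt{d}}, \qquad \mathrm{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}\right)\mathbf{V}$$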
In seq2seq with attention, the Query is the decoder's input (its hidden state from the previous step), while the Keys & Values are both the encoder's outputs (its hidden states at every time step)
A commonly used positional encoding method uses sine and cosine functions of different frequencies:
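For position $i$ and dimensions $2j$ / $2j+1$ of the encoding matrix $\mathbf{P} \in \mathbb{R}^{n \times d}$ (the fixed sinusoidal scheme used in the Transformer and in d2l):

$$p_{i,\,2j} = \sin\!\left(\frac{i}{10000^{2j/d}}\right), \qquad p_{i,\,2j+1} = \cos\!\left(\frac{i}{10000^{2j/d}}\right)$$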
Multi-head Attention aims to capture different "relationships" between Query and Key by using multiple parallel attention layers (heads) and concatenating their outputs to get the final result.
Mathematically:
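With per-head learnable projections $\mathbf{W}_i^{(q)}, \mathbf{W}_i^{(k)}, \mathbf{W}_i^{(v)}$ and an attention pooling $f$ (additive or scaled dot-product attention), head $i$ computes

$$\mathbf{h}_i = f\!\left(\mathbf{W}_i^{(q)}\mathbf{q},\ \mathbf{W}_i^{(k)}\mathbf{k},\ \mathbf{W}_i^{(v)}\mathbf{v}\right), \qquad i = 1, \dots, h,$$

and the final output is one more linear projection of the concatenated heads:

$$\mathbf{W}_o \begin{bmatrix} \mathbf{h}_1 \\ \vdots \\ \mathbf{h}_h \end{bmatrix}$$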
```python
import torch
from torch import nn
from d2l import torch as d2l


class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # keys and values must contain the same number of items;
        # num_queries may differ

        # initial queries:
        #   (batch_size, num_queries, num_hiddens)
        # transformed queries:
        #   (batch_size * num_heads, num_queries, num_hiddens / num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # replicate each sequence's valid length once per head
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # (batch_size * num_heads, num_queries, num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # (batch_size, num_queries, num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
```

```python
def transpose_qkv(X, num_heads):
    # (batch_size, num_queries, num_hiddens)
    # -> (batch_size, num_queries, num_heads, num_hiddens / num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # -> (batch_size, num_heads, num_queries, num_hiddens / num_heads)
    X = X.permute(0, 2, 1, 3)

    # -> (batch_size * num_heads, num_queries, num_hiddens / num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    """Reverse `transpose_qkv`."""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
```
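A quick shape check of the block above (the dimensions are chosen arbitrarily for illustration; the output should be (batch_size, num_queries, num_hiddens)):

```python
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, dropout=0.5)
attention.eval()  # disable dropout for a deterministic check

batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
print(attention(X, Y, Y, valid_lens).shape)  # torch.Size([2, 4, 100])
```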
Annotated graph
