
Natural Language Processing

Autumn 2021
Philippe Schlattner philippe.schlattner@inf.ethz.ch

Week 11 Exercises
Question 1: Transformers as n-gram models
For this problem, we consider language models parameterized by transformers, i.e., only
the decoder portion of the transformer architecture, equipped with self-attention. With a
transformer, we can only consider a fixed-size window of context. Note that this is different
from recurrent neural networks, which can, in theory, encode information from any number
of previous steps.

(a) Recall that the transformer architecture uses scaled dot-product attention:

$$\alpha^{(t)} = \operatorname{softmax}(\operatorname{score}(q_t, K)) = \frac{q_t^\top K}{\sqrt{h}} \tag{1}$$
where for input $x_t \in \mathbb{R}^h$, i.e., a vector representation (embedding) of the token at
position $t$ of input $X$, we define $q_t = W_q x_t$, $K = W_K X$ and $h$ is the hidden state
dimension. In terms of the size of the context window d, what is the time and space
complexity of a transformer?
Hint: think about what the dimensions of the matrices Wq and Wk must be.

(b) In an n-gram language model, we typically use a small window size, e.g., d = 4 ↔
n = 5. Would you expect a transformer with a context window of size 4 to perform
better or worse (in terms of perplexity on a held-out test set) than an n-gram model
where n = 5? Does your answer change depending on the settings (e.g., amount of
data, number of learnable parameters in the transformer)? Explain your reasoning.
You may assume the n-gram model was estimated with smoothing, and so perplexity
will not diverge.

(a) The runtime and space complexity of a transformer in terms of its context window
size d reduces to the runtime and space complexity of its attention mechanism.
Thus, we are interested in the complexity of the context computation. Note that
we consider the context window to be of size d, where d = n (i.e., attention
over the full input sequence of length n) is a special, and the most commonly used, case.

$$c^{(t)} = \alpha^{(t)} V_t^\top = \operatorname{softmax}(\operatorname{score}(q_t, K_t))\, V_t^\top = \operatorname{softmax}\!\left(\frac{q_t^\top K_t}{\sqrt{h}}\right) V_t^\top, \quad t \in [1..n] \tag{2}$$

In the case of d < n, the keys and values vary with the window location. We therefore
add a subscript to the keys and values to reflect this change relative to the original
Transformer attention notation.

Furthermore, using a more common notation (with $q_t \in \mathbb{R}^{1 \times h}$):

$$c^{(t)} = \alpha^{(t)} V_t = \operatorname{softmax}(\operatorname{score}(q_t, K_t))\, V_t = \operatorname{softmax}\!\left(\frac{q_t K_t^\top}{\sqrt{h}}\right) V_t, \quad t \in [1..n] \tag{3}$$

where $q_t$, $K_t$ and $V_t$ are the query, keys and values respectively. For each position
$t \in [1..n]$ in the input sequence, the complexities can then be derived considering the
following matrix dimensions:

• $X \in \mathbb{R}^{n \times h}$, the input to the attention
• $W_q, W_k, W_v \in \mathbb{R}^{h \times h}$, projections to build the query, keys and values
• $q_t \in \mathbb{R}^{1 \times h}$
• $K_t, V_t \in \mathbb{R}^{d \times h}$

Following the notation in equation 3, we build the queries, keys and values:

• $Q = X W_q \in \mathbb{R}^{n \times h} \rightarrow O(n \cdot h^2)$
• $K = X W_k \in \mathbb{R}^{n \times h} \rightarrow O(n \cdot h^2)$
• $V = X W_v \in \mathbb{R}^{n \times h} \rightarrow O(n \cdot h^2)$

The matrices $q_t$, $K_t$ and $V_t$ can then be extracted by taking the respective slices of
$Q$, $K$ and $V$. Then, we derive the complexities of the context computation:

• $q_t K_t^\top \in \mathbb{R}^{1 \times d} \rightarrow O(h \cdot d)$
• $\frac{q_t K_t^\top}{\sqrt{h}} \in \mathbb{R}^{1 \times d} \rightarrow O(h \cdot d)$
• $\operatorname{softmax}\!\left(\frac{q_t K_t^\top}{\sqrt{h}}\right) \in \mathbb{R}^{1 \times d} \rightarrow O(h \cdot d)$
• $\operatorname{softmax}\!\left(\frac{q_t K_t^\top}{\sqrt{h}}\right) V_t \in \mathbb{R}^{1 \times h} \rightarrow O(h \cdot d)$
• $\operatorname{softmax}\!\left(\frac{q_t K_t^\top}{\sqrt{h}}\right) V_t \in \mathbb{R}^{1 \times h},\ t \in [1..n] \rightarrow O(n \cdot h \cdot d)$

Complexity of the linear transformations to get $Q$, $K$ and $V$: $O(n \cdot h^2)$.
Complexity of the context computation: $O(n \cdot h \cdot d)$.
Overall complexity: $O(n \cdot h^2 + n \cdot h \cdot d)$.

Note: Be aware that the right-most expression in equation 1 omits the final softmax, as it does
not change the complexity of the computation.
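
To make the shapes and operation counts above concrete, here is a minimal NumPy sketch of the windowed attention in equation (3) for a single head. The function and variable names are illustrative, and a simple causal window over the $d$ most recent positions is assumed; this is a sketch of the complexity argument, not a full transformer layer.

```python
import numpy as np

def windowed_attention(X, W_q, W_k, W_v, d):
    """Scaled dot-product attention where position t attends only to a causal
    window of (at most) the d most recent positions, following eq. (3)."""
    n, h = X.shape
    # Linear projections: each is an (n x h) @ (h x h) product -> O(n * h^2)
    Q = X @ W_q   # queries, shape (n, h)
    K = X @ W_k   # keys,    shape (n, h)
    V = X @ W_v   # values,  shape (n, h)

    C = np.zeros((n, h))
    for t in range(n):
        lo = max(0, t - d + 1)                 # start of the context window
        K_t, V_t = K[lo:t + 1], V[lo:t + 1]    # shapes (<= d, h)
        scores = Q[t] @ K_t.T / np.sqrt(h)     # q_t K_t^T / sqrt(h) -> O(h * d)
        alpha = np.exp(scores - scores.max())
        alpha /= alpha.sum()                   # softmax over the window -> O(d)
        C[t] = alpha @ V_t                     # context vector c^(t)    -> O(h * d)
    return C                                   # loop over all t         -> O(n * h * d)

# Toy example: n = 8 tokens, hidden size h = 16, window size d = 4.
rng = np.random.default_rng(0)
n, h, d = 8, 16, 4
X = rng.normal(size=(n, h))
W_q, W_k, W_v = (rng.normal(size=(h, h)) for _ in range(3))
print(windowed_attention(X, W_q, W_k, W_v, d).shape)  # (8, 16)
```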

(b) In most settings, a transformer will likely give superior performance, but the answer to this
question depends largely on the number of learnable parameters in the transformer and
the amount of smoothing used in the n-gram model. For small datasets, we may expect
that both a transformer with a large number of learnable parameters and an n-gram model
without smoothing will perform poorly due to overfitting. As even a small transformer may
be too flexible in this setting, it is likely that an n-gram model with smoothing will perform
the best.
For larger datasets, we may expect a transformer to perform better, as, e.g., the use of
word embeddings can help it generalize. However, with an infinite amount of
data, an n-gram model without smoothing should perfectly represent the distribution over
language (when only given 4 words of context); we should expect that a transformer can
only perform as well as an n-gram model in this case.

Question 2: Sequence-to-sequence Models


Consider a simple encoder–decoder network (without attention for now) applied to a sequence-
to-sequence task where both encoder and decoder are parameterized by recurrent neural
networks. Formally, the source sequence $x = \langle x_1, \ldots, x_N \rangle$ is encoded using the recursion:

$$h_n^{(s)} = f\left(W_1^{(s)} h_{n-1}^{(s)} + W_2^{(s)} e(x_n)\right) \tag{4}$$

for predefined $h_0^{(s)}$ and activation function $f$ to produce representations $h_1^{(s)}, \ldots, h_N^{(s)}$. We
have written $e(x_n)$ for the embedding of the input token $x_n$. Similarly, let the decoder
output hidden states

$$h_m^{(t)} = f\left(W_1^{(t)} h_{m-1}^{(t)} + W_2^{(t)} e(y_{m-1})\right) \tag{5}$$

for target sequence $y = \langle y_1, \ldots, y_M \rangle$, where $h_0^{(t)} = h_N^{(s)}$, the last hidden state output by
our encoder. We have written $e(y_m)$ for the embedding of the input token $y_m$. Let $l_{m+1}$
represent the loss at position $m+1$ of the target sequence.

(a) Compute the derivative $\frac{\partial l_{m+1}}{\partial h_n^{(s)}}$ of the loss with respect to the encoding at each source
position (you can leave this in terms of $f'$ and $\partial l$). In terms of $n$, $m$, $N$, give a bound
on the number of terms in your expression. Feel free to use big-O notation. Conclude
by giving the runtime of backpropagation on this network.
(b) Now consider the same model augmented with standard, softmax attention (only over
the encoder), i.e.,

$$h_m^{(t)} = f\left(W_1^{(t)} h_{m-1}^{(t)} + W_2^{(t)} e(y_{m-1}) + W_3^{(t)} c_m\right) \tag{6}$$

where $c_m = \sum_{n=1}^{N} \alpha_{m,n} h_n^{(s)}$ and the vector $\alpha_m$ is calculated as

$$\alpha_m = \Phi\left(h_{m-1}^{(t)}, [h_1^{(s)}, \ldots, h_N^{(s)}]\right) \tag{7}$$

where $\Phi$ maps the vectors into a probability distribution, e.g., an MLP with a final
softmax layer. Similarly to (a), compute the derivative $\frac{\partial l_{m+1}}{\partial h_n^{(s)}}$ of the loss with respect
to the encoding at each source position (you can leave this in terms of $f'$, $\partial l$ and $\Phi'$).
In terms of $n$, $m$, $N$, give a bound on the number of terms in the chain-rule expansion
of $\frac{\partial l_{m+1}}{\partial h_n^{(s)}}$. Feel free to use big-O notation. Conclude by giving the runtime of back-
propagation on this network. What is the difference in runtime between performing
backpropagation on an attention-based network and a non-attention-based network?

(a)
$$\frac{\partial l_{m+1}}{\partial h_n^{(s)}} = \frac{\partial l_{m+1}}{\partial h_m^{(t)}} \frac{\partial h_m^{(t)}}{\partial h_{m-1}^{(t)}} \cdots \frac{\partial h_2^{(t)}}{\partial h_1^{(t)}} \frac{\partial h_1^{(t)}}{\partial h_N^{(s)}} \frac{\partial h_N^{(s)}}{\partial h_{N-1}^{(s)}} \cdots \frac{\partial h_{n+1}^{(s)}}{\partial h_n^{(s)}}$$

where

(D1) $\dfrac{\partial h_m^{(t)}}{\partial h_{m-1}^{(t)}} = f'\!\left(W_1^{(t)} h_{m-1}^{(t)} + W_2^{(t)} e(y_{m-1})\right) \cdot W_1^{(t)} \rightarrow O(1)$ terms w.r.t. $\{m, n, N\}$

(D2) $\dfrac{\partial h_n^{(s)}}{\partial h_{n-1}^{(s)}} = f'\!\left(W_1^{(s)} h_{n-1}^{(s)} + W_2^{(s)} e(x_n)\right) \cdot W_1^{(s)} \rightarrow O(1)$ terms w.r.t. $\{m, n, N\}$

Case D1 is computed for every decoder state $\leq m$, giving $O(m)$ terms.
Case D2 is computed for every encoder state $\geq n$, giving $O(N - n)$ terms.
The total number of terms on the shortest path is therefore $O(N - n + m)$.
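
The chain-rule structure above can be checked numerically. The sketch below (all names and sizes are illustrative, with $f = \tanh$) builds $\partial h_m^{(t)}/\partial h_n^{(s)}$ as a product of $m$ decoder Jacobians (case D1) followed by $N - n$ encoder Jacobians (case D2), i.e., $O(N - n + m)$ factors; multiplying this by $\partial l_{m+1}/\partial h_m^{(t)}$ would give the full gradient.

```python
import numpy as np

def rnn_jacobian(W, h_prev, Wx_x):
    """Jacobian of h = tanh(W @ h_prev + Wx_x) w.r.t. h_prev (cases D1 / D2)."""
    pre = W @ h_prev + Wx_x
    return np.diag(1.0 - np.tanh(pre) ** 2) @ W   # f'(.) * W

# Toy setup: hidden size 3, source length N = 5, decoder position m = 4, n = 2.
rng = np.random.default_rng(0)
h_dim, N, m, n = 3, 5, 4, 2
W1s, W2s, W1t, W2t = (0.5 * rng.normal(size=(h_dim, h_dim)) for _ in range(4))
x_emb = rng.normal(size=(N, h_dim))      # e(x_1), ..., e(x_N)
y_emb = rng.normal(size=(m, h_dim))      # e(y_0), ..., e(y_{m-1})

# Forward pass recording the hidden states (eqs. 4 and 5 with f = tanh).
h_s = [np.zeros(h_dim)]                  # h_0^(s)
for i in range(N):
    h_s.append(np.tanh(W1s @ h_s[-1] + W2s @ x_emb[i]))
h_t = [h_s[-1]]                          # h_0^(t) = h_N^(s)
for j in range(m):
    h_t.append(np.tanh(W1t @ h_t[-1] + W2t @ y_emb[j]))

# d h_m^(t) / d h_n^(s): product of m decoder Jacobians (D1)
# followed by N - n encoder Jacobians (D2) -> O(N - n + m) factors.
J = np.eye(h_dim)
for j in range(m, 0, -1):                # decoder steps m, ..., 1
    J = J @ rnn_jacobian(W1t, h_t[j - 1], W2t @ y_emb[j - 1])
for i in range(N, n, -1):                # encoder steps N, ..., n+1
    J = J @ rnn_jacobian(W1s, h_s[i - 1], W2s @ x_emb[i - 1])
print(J.shape)   # (3, 3); left-multiply by dl_{m+1}/dh_m^(t) for the gradient
```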
(b)
$$\frac{\partial l_{m+1}}{\partial h_n^{(s)}} = \frac{\partial l_{m+1}}{\partial h_m^{(t)}} \frac{\partial h_m^{(t)}}{\partial h_n^{(s)}}$$

where, writing $A = W_1^{(t)} h_{m-1}^{(t)}$, $B = W_2^{(t)} e(y_{m-1})$ and $C = W_3^{(t)} c_m$,

$$\begin{aligned}
\frac{\partial h_m^{(t)}}{\partial h_n^{(s)}} &= \frac{\partial}{\partial h_n^{(s)}} f\!\left(W_1^{(t)} h_{m-1}^{(t)} + W_2^{(t)} e(y_{m-1}) + W_3^{(t)} c_m\right) && (8)\\
&= \frac{\partial}{\partial h_n^{(s)}} f(A + B + C) && (9)\\
&= f'(A + B + C) \cdot \frac{\partial}{\partial h_n^{(s)}}\left[A + B + C\right] && (10)\\
&= f'(A + B + C) \cdot \frac{\partial}{\partial h_n^{(s)}}\left[A + C\right], \quad \text{since } B \text{ is constant w.r.t. } h_n^{(s)} && (11)\\
&= f'(A + B + C) \cdot \left[\frac{\partial A}{\partial h_n^{(s)}} + \frac{\partial C}{\partial h_n^{(s)}}\right] && (12)\\
&= f'\!\left(W_1^{(t)} h_{m-1}^{(t)} + W_2^{(t)} e(y_{m-1}) + W_3^{(t)} c_m\right)\left[W_1^{(t)} \frac{\partial h_{m-1}^{(t)}}{\partial h_n^{(s)}} + W_3^{(t)} \frac{\partial c_m}{\partial h_n^{(s)}}\right] && (13)
\end{aligned}$$

(I) As for the first factor of the outer multiplication: given

• $W_1^{(t)} h_{m-1}^{(t)} \rightarrow O(1)$ terms
• $W_2^{(t)} e(y_{m-1}) \rightarrow O(1)$ terms
• $W_3^{(t)} c_m \rightarrow O(N)$ terms

the factor

$$f'\!\left(W_1^{(t)} h_{m-1}^{(t)} + W_2^{(t)} e(y_{m-1}) + W_3^{(t)} c_m\right) \tag{14}$$

contains $O(N)$ terms w.r.t. $\{m, n, N\}$.

(II) As for the second term within the second factor of the outer multiplication:

$$\begin{aligned}
\frac{\partial c_m}{\partial h_n^{(s)}} &= \frac{\partial}{\partial h_n^{(s)}} \alpha_{m,n} h_n^{(s)} + \frac{\partial}{\partial h_n^{(s)}} \sum_{i \neq n}^{N} \alpha_{m,i} h_i^{(s)} && (15)\\
&= \alpha_{m,n} + \frac{\partial \alpha_{m,n}}{\partial h_n^{(s)}} h_n^{(s)} + \frac{\partial}{\partial h_n^{(s)}} \sum_{i \neq n}^{N} \alpha_{m,i} h_i^{(s)} && (16)\\
&= \Phi_{m,n} + \Phi'_{m,n} h_n^{(s)} + \sum_{i \neq n}^{N} \left[\Phi_{m,i} \frac{\partial h_i^{(s)}}{\partial h_n^{(s)}} + \Phi'_{m,i} h_i^{(s)}\right] && (17)\\
&= \Phi_{m,n} + \Phi'_{m,n} h_n^{(s)} + \sum_{i \neq n}^{N} \Phi_{m,i} \frac{\partial h_i^{(s)}}{\partial h_n^{(s)}} + \sum_{i \neq n}^{N} \Phi'_{m,i} h_i^{(s)} && (18)
\end{aligned}$$

with

• $\Phi_{m,n} \rightarrow O(1)$
• $\Phi'_{m,n} h_n^{(s)} \rightarrow O(1)$
• $\sum_{i \neq n}^{N} \Phi_{m,i} \frac{\partial h_i^{(s)}}{\partial h_n^{(s)}} \rightarrow O((N - n)^2)$ if we add all the terms from the chain rule in
each iteration, and $O(N - n)$ if we reuse the terms by factorizing them. Note
that for $i \geq n$ the chain rule is the same as in Q2(a); for the other cases it is $0$.
• $\sum_{i \neq n}^{N} \Phi'_{m,i} h_i^{(s)} \rightarrow O(N)$

Thus this term contains $O(N)$ terms w.r.t. $\{m, n, N\}$.

(III) As for the first term within the second factor of the outer multiplication:

$$W_1^{(t)} \frac{\partial h_{m-1}^{(t)}}{\partial h_n^{(s)}} \tag{19}$$

We can expand this derivative using the recursion. While taking derivatives of
decoder states in this recursion we add $O(N)$ additional terms at each step; while
taking derivatives of encoder states we add $O(1)$ additional terms at each step
(same reasoning as in Q2(a)).

Overall, the total number of terms on the shortest path is $O(N - n + m \cdot N)$.
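
As a numerical sanity check of the direct attention contributions in equation (18) (the $\Phi_{m,n}$, $\Phi'_{m,n} h_n^{(s)}$ and $\Phi'_{m,i} h_i^{(s)}$ terms, treating the encoder states as independent inputs and therefore ignoring the $\partial h_i^{(s)}/\partial h_n^{(s)}$ chains), the sketch below compares an analytic Jacobian of $c_m$ with finite differences. A simple bilinear score followed by a softmax stands in for $\Phi$; this choice and all names are assumptions for illustration only.

```python
import numpy as np

def attention_context(h_prev, H_s, W_a):
    """c_m = sum_n alpha_{m,n} h_n^(s) with alpha_m = softmax(H_s (W_a h_prev)).
    The bilinear score is an assumed stand-in for Phi."""
    q = W_a @ h_prev                       # shape (h,)
    scores = H_s @ q                       # shape (N,)
    alpha = np.exp(scores - scores.max())
    alpha /= alpha.sum()                   # softmax over encoder positions
    return alpha @ H_s, alpha, q           # c_m has shape (h,)

rng = np.random.default_rng(0)
N, h, n = 5, 4, 2                          # source length, hidden size, chosen position
H_s = rng.normal(size=(N, h))              # rows are h_1^(s), ..., h_N^(s)
h_prev = rng.normal(size=h)                # decoder state h_{m-1}^(t)
W_a = rng.normal(size=(h, h))
c, alpha, q = attention_context(h_prev, H_s, W_a)

# Analytic Jacobian dc_m/dh_n^(s) for this choice of Phi:
#   alpha_n * I + alpha_n * (h_n^(s) - c_m) q^T
J_analytic = alpha[n] * np.eye(h) + alpha[n] * np.outer(H_s[n] - c, q)

# Finite-difference check of the same Jacobian, one input coordinate at a time.
eps = 1e-6
J_fd = np.zeros((h, h))
for j in range(h):
    H_pert = H_s.copy()
    H_pert[n, j] += eps
    c_pert, _, _ = attention_context(h_prev, H_pert, W_a)
    J_fd[:, j] = (c_pert - c) / eps
print(np.max(np.abs(J_analytic - J_fd)))   # agreement up to finite-difference error
```

The nonzero Jacobian for every encoder position illustrates why attention gives the decoder a direct gradient path to each $h_n^{(s)}$, independent of how far back in the source sequence it sits.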

Question 3: Computational Cost of an Attention Layer


In this exercise, we are going to compute and compare the runtime complexity of two mod-
els for neural machine translation: RNN Encoder-Decoder and RNN Encoder-Decoder with
attention as defined in the previous question.

(a) Consider the encoder-decoder model defined in the previous question. Assume that an
$L$-layer LSTM is used as the encoder, so we rewrite eq. (4) as:

$$h_n^{l,(s)} = f\left(W_l^{(s)} h_{n-1}^{l,(s)} + W_{l-1}^{(s)} h_n^{l-1,(s)}\right) \tag{20}$$

where $h_n^{0,(s)} = e(x_n)$ and $h_n^{l,(s)}$ is the encoder hidden state at timestep $n$ and layer $l$.
For the decoder, a one-layer LSTM is used as before:

$$h_m^{(t)} = f\left(W_2^{(t)} h_{m-1}^{(t)} + W_1^{(t)} e(y_{m-1})\right) \tag{21}$$

Assume that the hidden and embedding vectors are in $\mathbb{R}^d$ and parameters $W_l^{(s)}, W_1^{(t)}, W_2^{(t)} \in
\mathbb{R}^{d \times d}$. With input length $N$ and output length $M$, what would be the runtime
complexity of the RNN Encoder-Decoder model?
(b) Now we slightly change the decoder by adding the attention mechanism to it, just like
in part (b) of the previous question. We further assume that each unit of the decoder only
attends to the last-layer embeddings given by the encoder. Using the same embedding
and hidden state size and input/output sequence lengths, what would be the runtime
complexity of computing the attention weights?
(c) We define the maximum path length across time as the shortest path length between
the first encoder input and the last decoder output. More intuitively, it shows how
easily the decoder model can access “information” from any given point in the input
sequence. Compare this property in RNN Encoder-Decoder and RNN Encoder-Decoder
with attention models.

(a) On the encoder side, in each iteration there are $O(d^2)$ hidden-to-hidden connections to
compute, which in total is $O(d^2 L N)$ for $L$ layers and input sequence length $N$ to get
the final hidden state $h_N^{L,(s)}$. Similarly, in the decoder there are $O(d^2)$
hidden-to-hidden connections, which in total is $O(d^2 M)$ for processing the whole output
sequence. So in total the runtime complexity is $O(d^2 L N + d^2 M)$.

(b) The computation of the attention weights takes $O(dN)$ for each decoder step, so the total
runtime complexity is $O(dNM)$.
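
To get a feel for how the $O(d^2LN + d^2M)$ recurrence cost from (a) compares with the $O(dNM)$ attention-weight cost, here is a rough multiply count (constant factors ignored, and all numbers purely illustrative): attention stays cheap as long as $NM \ll d(LN + M)$, but reaches the same order for long sequences.

```python
def rough_costs(d, L, N, M):
    """Leading-order multiply counts for the models in Q3 (constants ignored)."""
    recurrence = d * d * (L * N + M)   # O(d^2 L N + d^2 M), part (a)
    attention = d * N * M              # O(d N M), part (b)
    return recurrence, attention

# Illustrative settings: short vs. long sequences with d = 512, L = 2.
for d, L, N, M in [(512, 2, 30, 30), (512, 2, 1000, 1000)]:
    rec, att = rough_costs(d, L, N, M)
    print(f"N={N:5d}  recurrence={rec:.2e}  attention={att:.2e}")
```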

(c) In the standard RNN Encoder-Decoder, the maximum path length is $N + M$, i.e., the
"information" traverses every encoder and decoder state. With attention, however, the
"information" from any input token is accessible at any position during decoding. Therefore,
the RNN Encoder-Decoder with attention has a constant maximum path length, which
illustrates one benefit of adding the attention mechanism.
