
AN ATTENTION MATRIX FOR EVERY DECISION

Harsh Vishwakarma
21532
MTech, CSA

Deep Learning for NLP
E0-334
Table of Contents

Introduction

Optimus Transformer Interpretability

Applications of the JL lemma

Speeding up the JL transform
Introduction

What is the most important aspect of a model in a real-world scenario?

▶ The model needs to be interpretable.
▶ It should enhance performance on binary and multi-label data.
▶ The authors mainly focus on a new technique that selects the most faithful attention-based interpretation among the several ones that can be obtained by combining different head, layer, and matrix operations.
Interpretability

What is interpretability?

▶ A model's ability to provide insights into its decisions or inner workings, whether intrinsically or not, is referred to as interpretability.
▶ Complex models, such as transformers, cannot provide interpretations out of the box, and therefore post-hoc techniques are typically applied. The representations of an interpretation include, among others, rules, heatmaps, and feature importance.
Interpretability of the Transformer

How can we interpret the results generated by the transformer?

▶ The most popular transformer-specific interpretability approach is the use of self-attention scores.
▶ We can also generate attention maps to check which parts of the input receive the most attention for a particular input instance.
Feature-importance-based methods

▶ We can consider techniques like Layer-wise Relevance Propagation (LRP) to check the gradient flow during backpropagation, i.e., how the updates are made with respect to each feature.
▶ Some ready-to-use interpretation methods that use a similar idea are LIME, IG, and SHAP (a gradient-attribution sketch follows below).
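As a concrete illustration of the gradient-based attribution idea behind methods such as IG, here is a minimal, hedged PyTorch sketch. The toy linear model `f`, the zero baseline, and the step count are illustrative assumptions, not the setup used by the authors.

```python
import torch

def integrated_gradients(f, x, baseline=None, steps=50):
    """Approximate Integrated Gradients for a scalar-output model f.

    f        : callable mapping a (d,) tensor (or a batch of them) to scalars
    x        : input feature/embedding vector of shape (d,)
    baseline : reference point (defaults to the zero vector)
    """
    if baseline is None:
        baseline = torch.zeros_like(x)
    # Interpolate between the baseline and the input along a straight path.
    alphas = torch.linspace(0.0, 1.0, steps).unsqueeze(1)      # (steps, 1)
    path = baseline + alphas * (x - baseline)                  # (steps, d)
    path.requires_grad_(True)
    # Gradient of the prediction w.r.t. every point on the path.
    grads = torch.autograd.grad(f(path).sum(), path)[0]        # (steps, d)
    # Average the gradients and scale by the input difference.
    return (x - baseline) * grads.mean(dim=0)                  # (d,)

# Hypothetical usage with a toy linear "model":
w = torch.randn(8)
f = lambda z: z @ w
x = torch.randn(8)
print(integrated_gradients(f, x))   # one attribution score per feature
```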
How is interpretability evaluated?

▶ Comprehensibility: calculates the percentage of non-zero weights in an interpretation. The lower this number, the easier it is for end users to comprehend the interpretation.
▶ Faithfulness score: eliminates the token with the highest importance score from the examined instance and measures how much the prediction changes. Higher changes signify better interpretations.
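A minimal sketch of how these two metrics can be computed, assuming a `predict` wrapper (tokens → class probabilities) that is a hypothetical placeholder for the classifier:

```python
import numpy as np

def faithfulness(predict, tokens, weights, target_class):
    """Drop the most important token and measure the prediction change.

    predict      : callable, list of tokens -> class-probability vector
    tokens       : the input token sequence (list)
    weights      : importance score per token (the interpretation z)
    target_class : index of the class whose probability we track
    """
    base = predict(tokens)[target_class]
    top = int(np.argmax(weights))                  # most important token
    reduced = tokens[:top] + tokens[top + 1:]      # remove it
    return base - predict(reduced)[target_class]   # larger drop = more faithful

def comprehensibility(weights, tol=1e-8):
    """Fraction of (near) non-zero weights; lower is easier to read."""
    w = np.asarray(weights)
    return float(np.sum(np.abs(w) > tol)) / w.size
```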
Optimus Transformer Interpretability

Objective: Given a transformer model $f$ and an input sequence $x = [t_1, \ldots, t_S]$ consisting of $S$ tokens $t_i$, $i = 1, \ldots, S$, our goal is to extract a local interpretation $z = [w_1, \ldots, w_S]$, where $w_i \in \mathbb{R}$ signifies the influence of token $t_i$ on the model's decision $f(x)$, based on the model's self-attention scores.
Using Attention Scores

We know that the attention scores corresponding to each token are generated as:

▶ $A = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}} + \text{mask}\right)$, where $A \in \mathbb{R}^{S \times S}$ and $S$ is the sequence length.
▶ To obtain useful scores for both polarities in the interpretations they consider, the authors remove the softmax function and denote the resulting matrix by $A^{*}$.
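A minimal NumPy sketch of this computation for a single head; returning the raw scores corresponds to the $A^{*}$ matrix described above.

```python
import numpy as np

def attention_scores(Q, K, mask=None, use_softmax=False):
    """Scaled dot-product attention scores for one head.

    Q, K : (S, d) query and key matrices for a sequence of S tokens
    mask : optional (S, S) additive mask (e.g. large negatives on padding)
    Returns the (S, S) matrix A (with softmax) or A* (raw scores).
    """
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                 # (S, S)
    if mask is not None:
        scores = scores + mask
    if not use_softmax:
        return scores                             # A*: keeps both polarities
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)      # A: row-stochastic
```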
How the Attention Matrix Is Interpreted

Operations over Attention Matrices

Aggregation of the attention matrices: The process involves aggregating attention matrices across all heads within each self-attention layer.

Head operations: Common operations applied to the attention matrices of each head. Averaging and summing essentially give the same token-importance order, differing only in the magnitude of the scores assigned to tokens.
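A minimal sketch of the head-aggregation step; the set of operation names is illustrative (mean and sum are the ones discussed above, max is another common choice).

```python
import numpy as np

def aggregate_heads(A, op="mean"):
    """Combine the per-head attention matrices of one layer.

    A  : array of shape (H, S, S), one S x S matrix per head
    op : 'mean' or 'sum' (same token ranking, different magnitudes),
         or 'max' as another common choice
    """
    if op == "mean":
        return A.mean(axis=0)
    if op == "sum":
        return A.sum(axis=0)
    if op == "max":
        return A.max(axis=0)
    raise ValueError(f"unknown head operation: {op}")
```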
Operations over Attention Matrices

Final interpretation vector:

▶ Operations like "From [CLS]" and "To [CLS]" involve extracting the attention regarding the special [CLS] token that is typically prepended in text classification tasks. These operations consider the attention that the [CLS] token gives to and receives from the other tokens.
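A sketch of these matrix operations, assuming the [CLS] token sits at position 0, which is the usual convention for BERT-style classifiers:

```python
import numpy as np

def cls_interpretation(A, direction="from_cls", cls_index=0):
    """Turn an aggregated (S, S) attention matrix into a token-score vector
    using the special [CLS] token.

    'from_cls': attention the [CLS] token pays to every token (row of A)
    'to_cls'  : attention every token pays to [CLS]            (column of A)
    """
    if direction == "from_cls":
        return A[cls_index, :]        # shape (S,)
    if direction == "to_cls":
        return A[:, cls_index]        # shape (S,)
    raise ValueError(f"unknown matrix operation: {direction}")
```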
Selecting the Best Interpretation

▶ Select the best set of operations by iterating through different combinations of operations across the layers and heads of the transformer model.
▶ Find the most faithful interpretation for a single instance: iterate through the various combinations of head, layer, and matrix operations and, for each combination, evaluate the interpretation with the faithfulness metric. The combination with the highest faithfulness score is selected as the best interpretation.
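A schematic sketch of this selection loop. The operation dictionaries and the faithfulness callable are placeholders the caller would supply (e.g. wrapping the faithfulness metric from earlier); the actual Optimus implementation may organize this differently.

```python
import itertools

def best_interpretation(attentions, head_ops, layer_ops, matrix_ops,
                        faithfulness_fn):
    """Try every (head op, layer op, matrix op) combination and keep the
    interpretation with the highest faithfulness score.

    attentions      : list of per-layer (H, S, S) attention arrays for one instance
    *_ops           : dicts mapping an operation name to a callable
    faithfulness_fn : callable, token-score vector -> faithfulness score
    """
    best = (None, float("-inf"), None)
    for h, l, m in itertools.product(head_ops, layer_ops, matrix_ops):
        per_layer = [head_ops[h](A) for A in attentions]   # aggregate heads
        combined = layer_ops[l](per_layer)                 # aggregate layers
        scores = matrix_ops[m](combined)                   # e.g. from-[CLS] row
        f = faithfulness_fn(scores)
        if f > best[1]:
            best = (scores, f, (h, l, m))
    return best   # (interpretation, faithfulness, chosen operations)
```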
Lemma 3

Lemma: $P\left[\|f(x_i) - f(x_j)\|_2^2 \le (1-\epsilon)\|x_i - x_j\|_2^2\right] \le n^{-2}$

Proof: Write $u = x_i - x_j$, $v = f(x_i) - f(x_j)$, and let $x := p\,\|v\|_2^2 / \|u\|_2^2$, which follows the $\chi^2_p$ distribution. Then

$P\left[\|v\|_2^2 \le (1-\epsilon)\|u\|_2^2\right] = P\left[\frac{\|u\|_2^2\, x}{p} \le (1-\epsilon)\|u\|_2^2\right] = P\left[x \le (1-\epsilon)p\right] = P\left[e^{-\lambda x} \ge e^{-\lambda(1-\epsilon)p}\right]$  (for all $\lambda \ge 0$)

By Markov's inequality,

$P[x \ge a] \le \frac{E[x]}{a}$
Lemma 3

$P\left[e^{-\lambda x} \ge e^{-\lambda(1-\epsilon)p}\right] \le \frac{E\left[e^{-\lambda x}\right]}{e^{-\lambda(1-\epsilon)p}}$

$\le \frac{\prod_{i=1}^{p} E\left[e^{-\lambda x_i^2}\right]}{e^{-\lambda(1-\epsilon)p}}$  (as the $x_i$'s are i.i.d.)

$\le \left(\frac{E\left[e^{-\lambda x_i^2}\right]}{e^{-\lambda(1-\epsilon)}}\right)^{p}$

$\le \left(\frac{1}{\sqrt{1+2\lambda}\; e^{-\lambda(1-\epsilon)}}\right)^{p}$  (using the m.g.f. of the $\chi^2$ distribution)
Lemma 3

Put $\lambda = \frac{\epsilon}{2(1-\epsilon)}$ (the optimal value after differentiating):

$\le \left[(1-\epsilon)e^{\epsilon}\right]^{p/2}$

Using the inequality $\log(1-x) < -x - x^2/2$ and putting in the value of $p$,

$\le n^{-2}$

Combining the above lemmas we get

$P\left[\|v\|_2^2 \notin \left((1-\epsilon)\|u\|_2^2,\ (1+\epsilon)\|u\|_2^2\right)\right] \le \frac{2}{n^2}$

The above result is for any single pair of points. We can use the union bound over all such pairs:

$P(\text{bounds fail for some pair of points}) \le \sum_{i=1}^{\binom{n}{2}} P(\text{bounds fail for } x_i, x_j) \le \frac{n(n-1)}{2} \cdot \frac{2}{n^2} = 1 - \frac{1}{n}$
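A quick numerical sanity check of the lemma with a Gaussian random projection; the sizes, $\epsilon$, and the constant in $p$ are arbitrary toy choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, eps = 200, 1000, 0.3
p = int(np.ceil(12 * np.log(n) / eps**2))       # p = O(ln(n) / eps^2), generous constant

X = rng.normal(size=(n, d))                     # n points in R^d
U = rng.normal(size=(d, p))                     # i.i.d. standard normal entries
Y = X @ U / np.sqrt(p)                          # rows are f(x_i) = U^T x_i / sqrt(p)

# Distortion of every pairwise squared distance.
i, j = np.triu_indices(n, k=1)
orig = np.sum((X[i] - X[j]) ** 2, axis=1)
proj = np.sum((Y[i] - Y[j]) ** 2, axis=1)
ratio = proj / orig
print(ratio.min(), ratio.max())   # w.h.p. all ratios lie in [1 - eps, 1 + eps]
```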
Applications

▶ Low-Rank Matrix Approximation Using Random Projection
▶ Approximate Nearest Neighbor Search
Low-Rank Matrix Approximation Using Random Projection

Consider a matrix $X \in \mathbb{R}^{d \times n}$ where $d \ge n$. The time complexity of the SVD of this matrix is $O(dn^2)$.

Papadimitriou et al.: We can improve the time complexity to $O(dn\ln(n))$ using random projection.

Two-step process:

Step 1: Find a smaller matrix $Y$ by random projection:

$\mathbb{R}^{p \times n} \ni Y := \frac{1}{\sqrt{p}}\, U^{\top} X$

where $U \in \mathbb{R}^{d \times p}$ is a random projection matrix whose elements are drawn independently from a standard normal distribution, and $p \ge c\,\epsilon^{-2}\ln(n)$ for some $c > 0$.
Low-Rank Matrix Approximation Using Random Projection

Step 2: SVD of $Y$:

$Y = A \Lambda B^{\top} = \sum_{i=1}^{p} \lambda_i a_i b_i^{\top}$

$\mathbb{R}^{p \times p} \ni A$: left singular vectors of $Y$
$\mathbb{R}^{n \times p} \ni B$: right singular vectors of $Y$
$\mathbb{R}^{p \times p} \ni \Lambda$: singular values of $Y$

The SVD of $Y$ is much faster than the SVD of $X$ since $p \ll d$.

The matrix $X$ can then be approximated by its projection:

$\mathbb{R}^{d \times n} \ni \widetilde{X}_p := X \left(\sum_{i=1}^{p} b_i b_i^{\top}\right) \approx X$
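A minimal NumPy sketch of this two-step procedure; the toy matrix sizes and target rank below are arbitrary.

```python
import numpy as np

def low_rank_by_random_projection(X, p, rng=None):
    """Rank-p approximation of X (d x n, d >= n) via random projection.

    Step 1: Y = U^T X / sqrt(p) with U a d x p standard-normal matrix.
    Step 2: SVD of the small p x n matrix Y, then project X onto the span
            of its top right singular vectors.
    """
    rng = np.random.default_rng() if rng is None else rng
    d, n = X.shape
    U = rng.normal(size=(d, p))
    Y = U.T @ X / np.sqrt(p)                         # (p, n), cheap to decompose
    _, _, Bt = np.linalg.svd(Y, full_matrices=False)
    B = Bt.T[:, :p]                                  # right singular vectors of Y
    return X @ (B @ B.T)                             # X_tilde_p, still d x n

# Toy usage: relative Frobenius error of the rank-20 approximation.
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 100)) @ rng.normal(size=(100, 100))
Xp_rp = low_rank_by_random_projection(X, p=20, rng=rng)
print(np.linalg.norm(X - Xp_rp) / np.linalg.norm(X))
```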
Low-Rank Matrix Approximation Using Random Projection

Lemma 1: Let the SVD of $X$ be $X = C\Sigma E^{\top} = \sum_{i=1}^{d}\sigma_i c_i e_i^{\top}$, where $C = [c_1, \ldots, c_d] \in \mathbb{R}^{d \times d}$, $E = [e_1, \ldots, e_d] \in \mathbb{R}^{n \times d}$, and $\Sigma = \mathrm{diag}([\sigma_1, \ldots, \sigma_d]^{\top}) \in \mathbb{R}^{d \times d}$ are the left singular vectors, right singular vectors, and singular values, respectively. Also, let the SVD of $X$ truncated to the top $p$ singular values be $X_p := \sum_{i=1}^{p} \sigma_i c_i e_i^{\top}$. If $p \ge c\ln(n)/\epsilon^2$, the singular values of $Y$ are not much smaller than the singular values of $X$, i.e.:

$\sum_{i=1}^{p} \lambda_i^2 \ge (1-\epsilon)\sum_{i=1}^{p} \sigma_i^2 = (1-\epsilon)\,\|X_p\|_F^2$

where $\|\cdot\|_F$ is the Frobenius norm.
Low-Rank Matrix Approximation Using Random Projection

Lemma 2: The low-rank approximation $\widetilde{X}_p$ approximates $X$ well enough:

$\|X - \widetilde{X}_p\|_F^2 \le \|X - X_p\|_F^2 + 2\epsilon\,\|X_p\|_F^2$

Lemma 3: The time complexity of the low-rank approximation (with rank $p$) of matrix $X$ using random projection is $O(dn\ln(n))$, which is much better than the $O(dn^2)$ complexity of the SVD of $X$.
Approximate Nearest Neighbor Search

▶ Let $\mathbb{Z}_w^d$ denote the set of $d$-dimensional integer vectors with $w$ possible values per coordinate; e.g., $\mathbb{Z}_2^d := \{0,1\}^d$ is the set of binary vectors.
▶ $\|x_i - x_j\|_H$ denotes the Hamming distance between two binary vectors $x_i$ and $x_j$.
▶ Instead of the $\ell_2$ distance, we consider the $\ell_1$ distance, i.e., the Hamming distance.
Approximate Nearest Neighbor Search

Theorem (random projection onto the hypercube; Kushilevitz et al., 2000).
Consider a binary vector $x \in \mathbb{Z}_2^d$ which is projected as $f(x) = U^{\top}x \bmod 2$, with a random binary projection matrix $U \in \{0,1\}^{d \times p}$, where $p = O(\ln(n)/\epsilon^2)$ (usually $p \ll d$). The elements of $U$ are i.i.d. Bernoulli, equal to one with probability $\xi = \epsilon^2/\ell$ and zero with probability $(1-\xi)$. Then:

if $\|x_i - x_j\|_H < \frac{\ell}{4}$ $\implies$ $\|f(x_i) - f(x_j)\|_H < (1+\epsilon)\,p\,\xi\,\frac{\ell}{4}$

if $\frac{\ell}{4} \le \|x_i - x_j\|_H \le \frac{\ell}{2}$ $\implies$ $(1-\epsilon)\,p\,\xi \le \frac{\|f(x_i) - f(x_j)\|_H}{\|x_i - x_j\|_H} < (1+\epsilon)\,p\,\xi$

if $\|x_i - x_j\|_H > \frac{\ell}{2}$ $\implies$ $\|f(x_i) - f(x_j)\|_H > (1-\epsilon)\,p\,\xi\,\frac{\ell}{2}$

for all $x_i, x_j \in \mathbb{Z}_2^d$, with probability at least $1 - e^{-c\epsilon^4 p}$, where $c > 0$.
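A minimal sketch of the projection $f(x) = U^{\top}x \bmod 2$ described in the theorem; all parameter values in later usage are illustrative.

```python
import numpy as np

def hypercube_projection(X, p, ell, eps, rng=None):
    """Random projection of binary vectors onto a p-dimensional hypercube:
    f(x) = U^T x mod 2, with U_ij ~ Bernoulli(xi), xi = eps^2 / ell.

    X : (n, d) binary matrix, one vector per row
    Returns the (n, p) binary matrix of projected vectors.
    """
    rng = np.random.default_rng() if rng is None else rng
    d = X.shape[1]
    xi = eps**2 / ell
    U = (rng.random((d, p)) < xi).astype(np.int64)   # sparse binary projection matrix
    return (X @ U) % 2                               # mod-2 arithmetic keeps outputs binary

def hamming(a, b):
    """Hamming distance between two binary vectors."""
    return int(np.sum(a != b))
```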
Approximate Nearest Neighbor Search Using Random Projection

▶ Consider a dataset $X := \{x_i \in \mathbb{R}^d\}_{i=1}^{n}$. The nearest neighbor search problem refers to finding the closest point $x^* \in X$ to a query point $q \in \mathbb{R}^d$.
▶ The basic approach is to traverse all points and check them naively. Its time and space complexities are both $O(nd)$, which is not good.
▶ There is an algorithm for nearest neighbor search (Meiser, 1993) with time complexity $O(\mathrm{poly}(d, \ln(n)))$. However, the space complexity of this algorithm is $O(n^d)$.
Approximate Nearest Neighbor Search Using Random Projection

▶ Approximate nearest neighbor search returns a point $x^* \in X$ which satisfies:

$\|q - x^*\|_2^2 \le (1+\epsilon)\,\min_{x \in X}\|q - x\|_2^2$

▶ We can relax this definition if we take an acceptable distance $r$ from the user:

$\|q - x^*\|_2^2 \le (1+\epsilon)\,r$

If no such point is found in $X$, null is returned.
Main Idea

▶ The idea is that the smallest distance $r$ can be found by doing binary search on $\ell$, with $0 \le \ell \le d$.
▶ In the binary search for finding the smallest distance $r$, we should try several distances. For every distance, we perform $k$ independent random projections onto $k$ hypercubes, and then we select one of these random projections at random.
▶ In every random projection, the points of $X$ are projected onto a random $p$-dimensional hypercube to obtain $\{f(x_i) \in \mathbb{Z}_2^p\}_{i=1}^{n}$.
▶ In the low-dimensional projected subspace, the comparison of points is very fast because:
  • the subspace dimensionality $p$ is much smaller than the original dimensionality $d$;
  • the Hamming distance is faster and easier to compute than the Euclidean distance.
▶ Therefore, random projections onto hypercubes are very useful for approximate nearest neighbor search (see the toy usage below).
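A toy end-to-end usage, reusing the `hypercube_projection` helper sketched earlier. It fixes a single $\ell$ rather than performing the binary search, purely to illustrate how cheap the comparison in the projected hypercube is.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, p, ell, eps = 500, 256, 64, 32, 0.5

X = rng.integers(0, 2, size=(n, d))                  # binary dataset
q = rng.integers(0, 2, size=(1, d))                  # binary query

# Project the dataset and the query with the same random hypercube map.
proj = hypercube_projection(np.vstack([X, q]), p, ell, eps, rng)
Xp, qp = proj[:-1], proj[-1]

# Comparing Hamming distances in the p-dimensional cube is cheap (p << d).
dists = np.sum(Xp != qp, axis=1)
candidate = int(np.argmin(dists))                    # approximate nearest neighbor
print(candidate, np.sum(X[candidate] != q[0]))       # check it in the original space
```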
Lemma

Lemma: The above algorithm for approximate nearest neighbor search is correct with probability $(1-\delta)$, where $0 < \delta \ll 1$. The time complexity of the above algorithm is $O\!\left(\frac{d}{\epsilon^4}\ln\frac{n}{\delta}\ln(d)\right)$ and its space complexity is $O\!\left(\frac{d}{\epsilon^2}\,(c_1 n \ln(d))^{c_2/\epsilon}\right)$, where $c_1$ and $c_2$ are constants.
Speeding up the JL transform

▶ Given $A$, computing the matrix-vector product takes $O(pd)$ time. There has been work on deriving distributions for which the matrix-vector product can be computed in less than $O(pd)$ time.
▶ There are two major lines of work:
  • The Fast Johnson-Lindenstrauss Transform (FJLT), introduced by Ailon and Chazelle in 2006. This method allows the matrix-vector product to be computed in just $d\log d + p^{2+\gamma}$ time for any constant $\gamma > 0$, where the mapping is $\Phi = \frac{1}{k} U H D$.
  • The Sparse JL approach builds a distribution supported over sparse matrices. This method keeps only an $\epsilon$ fraction of the entries in the matrix, so the computation can be done in just $pd\epsilon$ time.
▶ Furthermore, if the vector has only $b$ non-zero entries, Sparse JL takes $pb\epsilon$ time, which may be much less than the $d\log d$ time used by Fast JL.
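A simplified sketch of the sparse-JL idea (keep only a small fraction of Gaussian entries and rescale so squared norms are preserved in expectation). This is an illustration of the "keep an ε fraction of the entries" point above, not the exact sparse JL constructions from the literature.

```python
import numpy as np
from scipy import sparse

def sparse_jl_matrix(d, p, density, rng=None):
    """Sparse JL-style projection: keep each Gaussian entry with probability
    `density` and rescale by 1/sqrt(p * density). Only kept entries are stored,
    so the matrix-vector product costs roughly density * p * d operations.
    """
    rng = np.random.default_rng() if rng is None else rng
    mask = rng.random((p, d)) < density                  # which entries survive
    vals = rng.normal(size=(p, d)) * mask / np.sqrt(p * density)
    return sparse.csr_matrix(vals)

rng = np.random.default_rng(0)
d, p, density = 10_000, 400, 0.05
A = sparse_jl_matrix(d, p, density, rng)
x = rng.normal(size=d)
y = A @ x                                                # sparse product
print(np.linalg.norm(y)**2 / np.linalg.norm(x)**2)       # should be close to 1
```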
Thank You!
