How does LTMP combine token merging and token pruning?
LTMP adds learned threshold masking modules that learn one threshold for merging and one for pruning. First, token pairs whose similarity score is above the merging threshold are merged; then tokens whose importance score is below the pruning threshold are pruned.
How many learnable parameters does LTMP introduce?
LTMP introduces only 2 learnable parameters per transformer block.
The learnable parameters are the thresholds, one for merging and one for pruning.
What loss function is used to train LTMP?
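$$\mathcal{L} = \mathcal{L}_{CE} + \lambda \left(r_{\text{target}} - r_{\text{FLOPs}}\right)^2$$
with
$$r_{\text{FLOPs}} \approx \sum_{l=1}^{L} \frac{1}{L} \left( \frac{2\bar{m}^{l-1} n d^2 + \left(\bar{m}^{l-1} n\right)^2 d + 4\bar{m}^{l} n d^2}{6 n d^2 + n^2 d} \right)$$
where the authors denote $\phi_{\text{module}}(n, d)$ as a function that calculates the FLOPs of a module based on the number of tokens $n$ and the embedding dimension $d$.
$\bar{m}^{l} = \frac{1}{n} \sum_{i=1}^{n} m_i^{l}$ is the percentage of input tokens that are kept after the $l$-th threshold masking operation, and $\bar{m}^{0} = 1$.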
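The FLOPs regularizer can be checked numerically with a small sketch. This is an illustration only, not the authors' code: the function names `flops_ratio` and `ltmp_loss`, and the example values for `n`, `d`, `r_target` and `lam`, are assumptions made for this example.

```python
def flops_ratio(keep_fractions, n, d):
    """Approximate r_FLOPs from the mean mask values m_bar^1 .. m_bar^L.

    keep_fractions[l-1] is m_bar^l; m_bar^0 is fixed to 1.
    """
    num_blocks = len(keep_fractions)
    full = 6 * n * d**2 + n**2 * d  # denominator: FLOPs of an unreduced block
    m_prev = 1.0
    total = 0.0
    for m_cur in keep_fractions:
        reduced = (2 * m_prev * n * d**2
                   + (m_prev * n) ** 2 * d
                   + 4 * m_cur * n * d**2)
        total += reduced / full
        m_prev = m_cur
    return total / num_blocks


def ltmp_loss(ce_loss, keep_fractions, n, d, r_target, lam):
    """Cross-entropy loss plus the squared FLOPs-target penalty."""
    r_flops = flops_ratio(keep_fractions, n, d)
    return ce_loss + lam * (r_target - r_flops) ** 2


# Illustrative values: 12 blocks, 197 tokens, embedding dimension 768.
no_reduction = flops_ratio([1.0] * 12, n=197, d=768)
half_kept = flops_ratio([0.5] * 12, n=197, d=768)
```

When every $\bar{m}^{l}$ equals 1 the numerator and denominator coincide, so the ratio is 1 and the penalty vanishes for a target of 1; keeping fewer tokens lowers the ratio.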
How do the learned threshold masks during LTMP mimic the effect of dropping tokens?
The attention function is modified such that it corresponds to attention as if it was only applied to the tokens that are not merged or pruned.
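$$\text{Attention\_with\_mask}(Q, K, V, m) = SV$$
where
$$S_{ij} = \frac{\exp(A_{ij})\, m_j}{\sum_{k=1}^{n} \exp(A_{ik})\, m_k}, \quad 1 \le i, j, k \le n$$
and
$$A = \frac{QK^{T}}{\sqrt{d_k}} \in \mathbb{R}^{n \times n}$$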
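A toy check of the masked softmax may help. The sketch below is not taken from the LTMP code: the function names and the score matrix are made up for illustration, and it assumes at least one token is kept so the denominator is nonzero. It compares the masked weights against an ordinary softmax over only the kept tokens.

```python
import math


def masked_attention_weights(A, m):
    """Row-wise softmax over scores A, with each key j weighted by m[j]."""
    S = []
    for row in A:
        row_max = max(row)  # shift for numerical stability; cancels in the ratio
        weighted = [math.exp(a - row_max) * m_j for a, m_j in zip(row, m)]
        denom = sum(weighted)
        S.append([w / denom for w in weighted])
    return S


def softmax(row):
    """Plain softmax, used as the reference for the kept tokens."""
    row_max = max(row)
    exps = [math.exp(a - row_max) for a in row]
    total = sum(exps)
    return [e / total for e in exps]


# Toy scores for three tokens; token 1 is masked out (merged or pruned).
A = [[0.2, 1.0, -0.5],
     [0.7, 0.1, 0.3],
     [-1.2, 0.4, 0.9]]
m = [1.0, 0.0, 1.0]

S = masked_attention_weights(A, m)
reference = softmax([A[0][0], A[0][2]])  # attention over kept tokens only
```

The masked token receives zero weight, and the remaining weights in row 0 match `reference` up to floating-point error, which is the sense in which the mask mimics dropping the token.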
What does LTMP use as the importance score for pruning?
LTMP uses the mean column attention score $s_i = \frac{1}{h \cdot n} \sum_{j=1}^{h} \sum_{k=1}^{n} S_{jki}$, which represents the attention that token $x_i$ receives.
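The score is an average over heads and query positions of the attention column belonging to token $i$. A minimal sketch, assuming the attention weights are stored as a nested list `S[j][k][i]` (head, query, key); this layout and the function name are assumptions for the example:

```python
def importance_scores(S):
    """Mean column attention: s_i = 1 / (h * n) * sum_j sum_k S[j][k][i]."""
    h = len(S)     # number of heads
    n = len(S[0])  # number of tokens
    scores = [0.0] * n
    for head in S:
        for row in head:
            for i, value in enumerate(row):
                scores[i] += value
    return [s / (h * n) for s in scores]


# One head, two tokens; each row of attention weights sums to 1.
toy_attention = [[[0.9, 0.1],
                  [0.6, 0.4]]]
toy_scores = importance_scores(toy_attention)
```

Because every attention row sums to 1, the scores over all tokens also sum to 1, so a token that receives little attention from the others gets a small score.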
What does the threshold masking module in LTMP look like?
$$M(s_i^{l}, \theta^{l}) = \begin{cases} 1, & \text{if } s_i^{l} > \theta^{l} \\ 0, & \text{otherwise} \end{cases}$$
where θ is the learned threshold.
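The hard threshold has no useful gradient, so to make the threshold learnable its gradient is estimated with a straight-through estimator: the forward pass uses the hard mask, while the backward pass uses
$$M(s_i^{l}, \theta^{l}) = \sigma\left(\frac{s_i^{l} - \theta^{l}}{\tau}\right)$$
where $\sigma$ is the sigmoid function and $\tau$ is a temperature.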
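The two passes can be written out by hand for a single score. This is a sketch under stated assumptions rather than the authors' implementation: the function names are hypothetical, and the example values for the score, threshold and `tau` are arbitrary.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def mask_forward(score, theta):
    """Forward pass: hard threshold, 1 keeps the token and 0 masks it."""
    return 1.0 if score > theta else 0.0


def mask_backward(score, theta, tau):
    """Backward pass: gradients of sigmoid((score - theta) / tau),
    used in place of the zero gradient of the hard threshold."""
    y = sigmoid((score - theta) / tau)
    grad_score = y * (1.0 - y) / tau
    grad_theta = -grad_score
    return grad_score, grad_theta


kept = mask_forward(0.3, theta=0.2)
dropped = mask_forward(0.1, theta=0.2)
grad_score, grad_theta = mask_backward(0.2, theta=0.2, tau=0.1)
```

The surrogate gradient is largest when the score sits at the threshold and shrinks as the score moves away from it, so tokens near the decision boundary influence the learned threshold the most.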
Draw an overview of the LTMP framework.