Learned Thresholds Token Merging and Pruning for Vision Transformers

How does LTMP combine token merging and token pruning?


LTMP adds learned threshold masking modules that learn a threshold for both merging and pruning. First, token pairs with a similarity score above the merging threshold are merged; then, tokens with an importance score below the pruning threshold are pruned.
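A toy PyTorch sketch of this per-block order of operations. Note the real method uses ToMe-style bipartite soft matching for merging; pairing adjacent even/odd tokens, averaging merged pairs, and the function name are illustrative assumptions, not the paper's implementation:

```python
import torch
import torch.nn.functional as F

def merge_then_prune(x, scores, theta_m, theta_p):
    """Merge similar token pairs, then prune unimportant tokens.

    x:      (n, d) token embeddings, n assumed even
    scores: (n,) importance scores (e.g. mean column attention)
    theta_m, theta_p: the block's learned merge / prune thresholds
    """
    a, b = x[0::2], x[1::2]                  # simplistic adjacent pairing
    sim = F.cosine_similarity(a, b, dim=-1)  # similarity per candidate pair
    # pairs above the merging threshold collapse into their average
    merged = torch.where(sim[:, None] > theta_m, (a + b) / 2, a)
    x = torch.cat([merged, b[sim <= theta_m]], dim=0)
    s = torch.cat([scores[0::2], scores[1::2][sim <= theta_m]], dim=0)
    return x[s > theta_p]                    # prune after merging
```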

How many learnable parameters does LTMP introduce?


LTMP introduces only 2 learnable parameters per transformer block.

The learnable parameters are the thresholds, one for merging and one for pruning.
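In code, that amounts to two scalar `nn.Parameter`s per block; a minimal sketch (the class name and initial values are assumptions, not taken from the paper):

```python
import torch
import torch.nn as nn

class LTMPThresholds(nn.Module):
    """The only learnable state LTMP adds to a transformer block."""
    def __init__(self, init_merge=0.9, init_prune=0.0):
        super().__init__()
        self.theta_merge = nn.Parameter(torch.tensor(init_merge))  # merging threshold
        self.theta_prune = nn.Parameter(torch.tensor(init_prune))  # pruning threshold
```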

What loss function is used to train LTMP?


$\mathcal{L} = \mathcal{L}_{CE} + \lambda(r_{target} - r_{\text{FLOPs}})^2$ with $r_{\text{FLOPs}} \approx \sum_{l=1}^L \frac{1}{L}\left(\frac{2\bar{\mathbf{m}}^{l-1}nd^2 + (\bar{\mathbf{m}}^{l-1}n)^2d + 4\bar{\mathbf{m}}^{l}nd^2}{6nd^2 + n^2d}\right)$, where $\phi_{\text{module}}(n,d)$ denotes the FLOPs of a module as a function of the number of tokens $n$ and the embedding dimension $d$, $\bar{\mathbf{m}}^l = \frac{1}{n}\sum_{i=1}^n \mathbf{m}^l_i$ is the fraction of input tokens kept after the $l$-th threshold masking operation, and $\bar{\mathbf{m}}^0 = 1$.
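A minimal sketch of this loss in PyTorch, assuming the per-block kept fractions $\bar{\mathbf{m}}^l$ are passed in as differentiable scalars with $\bar{\mathbf{m}}^0 = 1$ prepended:

```python
import torch
import torch.nn.functional as F

def ltmp_loss(logits, targets, m_bars, n, d, r_target, lam):
    """Cross-entropy plus the squared FLOPs-ratio penalty.

    m_bars:   list of L+1 kept-token fractions, m_bars[0] == 1.0
    n, d:     token count and embedding dimension of the full model
    r_target: desired FLOPs ratio; lam: regularization weight lambda
    """
    ce = F.cross_entropy(logits, targets)
    L = len(m_bars) - 1
    denom = 6 * n * d**2 + n**2 * d          # FLOPs of an unreduced block
    r_flops = sum(
        (2 * m_bars[l - 1] * n * d**2
         + (m_bars[l - 1] * n) ** 2 * d
         + 4 * m_bars[l] * n * d**2) / denom
        for l in range(1, L + 1)
    ) / L
    return ce + lam * (r_target - r_flops) ** 2
```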

How do the learned threshold masks during LTMP mimic the effect of dropping tokens?


The attention function is modified so that it corresponds to attention applied only to the tokens that have not been merged or pruned: $\operatorname{Attention\_with\_mask}(\mathbf{Q}, \mathbf{K}, \mathbf{V}, \mathbf{m}) = \mathbf{S}\mathbf{V}$ where $\mathbf{S}_{ij} = \frac{\exp(\mathbf{A}_{ij})\mathbf{m}_{j}}{\sum_{k=1}^n\exp(\mathbf{A}_{ik})\mathbf{m}_{k}}$ with $1\le i,j,k\le n$, and $\mathbf{A} = \mathbf{Q}\mathbf{K}^T/\sqrt{d_k} \in \mathbb{R}^{n\times n}$.
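A direct translation of these equations into PyTorch; the max-subtraction is added only for numerical stability and cancels in the ratio:

```python
import torch

def attention_with_mask(q, k, v, m):
    """Masked softmax attention: masked-out columns get zero weight.

    q, k, v: (n, d_k) query/key/value matrices
    m:       (n,) keep-mask; 1 keeps a token, 0 removes it
    """
    d_k = q.shape[-1]
    a = q @ k.transpose(-2, -1) / d_k**0.5                     # A = QK^T / sqrt(d_k)
    w = torch.exp(a - a.max(dim=-1, keepdim=True).values) * m  # exp(A_ij) * m_j
    s = w / w.sum(dim=-1, keepdim=True)                        # renormalize each row
    return s @ v                                               # S V
```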

What does LTMP use as the importance score for pruning?


LTMP uses the mean column attention score $s_i = \frac{1}{h \cdot n}\sum_{j=1}^h \sum_{k=1}^n S_{jki}$, which represents the attention $x_i$ receives.
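This is a single reduction over heads and query tokens; a sketch assuming $\mathbf{S}$ is stacked over the $h$ heads:

```python
import torch

def importance_scores(S):
    """Mean column attention: s_i = (1 / (h*n)) * sum_{j,k} S[j, k, i].

    S: (h, n, n) attention weights after masking
    returns: (n,) importance score per token
    """
    h, n, _ = S.shape
    return S.sum(dim=(0, 1)) / (h * n)
```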

What does the threshold masking module in LTMP look like?


$M(\mathbf{s}^l_i, \theta^l) = \begin{cases} 1, &\text{if } \mathbf{s}^l_i > \theta^l\\ 0, &\text{otherwise} \end{cases}$ where $\theta^l$ is the learned threshold. Since this hard threshold function is not differentiable, the backward pass approximates it with a sigmoid, used as a straight-through estimator: $M(\mathbf{s}^l_i, \theta^l) = \sigma\left(\frac{\mathbf{s}^l_i - \theta^l}{\tau}\right)$
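A sketch of the straight-through trick in PyTorch; the temperature value is an assumed choice, not taken from the paper:

```python
import torch

def threshold_mask(s, theta, tau=0.1):
    """Hard 0/1 mask in the forward pass, sigmoid gradient in the backward.

    s:     (n,) importance or similarity scores
    theta: learned scalar threshold; tau: sigmoid temperature (assumed)
    """
    soft = torch.sigmoid((s - theta) / tau)  # differentiable surrogate
    hard = (s > theta).float()               # the actual mask M(s, theta)
    # forward value is `hard`; gradients flow through `soft`
    return soft + (hard - soft).detach()
```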

Draw an overview of the LTMP framework.


paste-b12d0399be9f473302709bc2ab178c68d6233c88.jpg
