Token Sparsification for Faster Medical Image Segmentation
Main Authors | , , , , , |
---|---|
Format | Journal Article |
Language | English |
Published | 11.03.2023 |
Summary: | Can we use sparse tokens for dense prediction, e.g., segmentation? Although
token sparsification has been applied to Vision Transformers (ViT) to
accelerate classification, it is still unknown how to perform segmentation from
sparse tokens. To this end, we reformulate segmentation as a sparse encoding ->
token completion -> dense decoding (SCD) pipeline. We first empirically show
that naively applying existing approaches from classification token pruning and
masked image modeling (MIM) leads to failure and inefficient training caused by
inappropriate sampling algorithms and the low quality of the restored dense
features. In this paper, we propose Soft-topK Token Pruning (STP) and
Multi-layer Token Assembly (MTA) to address these problems. In sparse encoding,
STP predicts token importance scores with a lightweight sub-network and samples
the topK tokens. The intractable topK gradients are approximated through a
continuous perturbed score distribution. In token completion, MTA restores a
full token sequence by assembling both sparse output tokens and pruned
multi-layer intermediate ones. The last dense decoding stage is compatible with
existing segmentation decoders, e.g., UNETR. Experiments show that SCD pipelines
equipped with STP and MTA are much faster than baselines without token pruning,
in both training (up to 120% higher throughput) and inference (up to 60.6%
higher throughput), while maintaining segmentation quality. |
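The abstract's description of STP — score tokens with a light sub-network, keep the top-K, and approximate the intractable top-K gradients via a continuous perturbed score distribution — can be illustrated with a minimal NumPy sketch. This is not the paper's implementation: `perturbed_topk`, the Gaussian noise model, and the Monte-Carlo averaging are illustrative stand-ins for the perturbed-optimizer idea, where the expectation of hard top-K masks over noisy scores is smooth in the scores.

```python
import numpy as np

def perturbed_topk(scores, k, num_samples=100, sigma=0.05, rng=None):
    """Soft top-K selection mask (hypothetical sketch of STP's idea).

    Averages hard top-K indicator vectors over Gaussian-perturbed copies
    of the scores. The resulting soft mask lies in [0, 1]^n, sums to k,
    and its expectation is differentiable w.r.t. the scores, which is
    what makes gradient approximation possible.
    """
    rng = np.random.default_rng(0) if rng is None else rng
    n = scores.shape[0]
    soft_mask = np.zeros(n)
    for _ in range(num_samples):
        noisy = scores + sigma * rng.standard_normal(n)
        # Indices of the k largest perturbed scores (hard top-K sample).
        topk_idx = np.argpartition(-noisy, k)[:k]
        hard_mask = np.zeros(n)
        hard_mask[topk_idx] = 1.0
        soft_mask += hard_mask
    return soft_mask / num_samples
```

In a real pipeline the scores would come from the lightweight scoring sub-network, and the averaging would be replaced by an analytic or reparameterized gradient estimator rather than a Monte-Carlo loop.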
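The token-completion stage (MTA) restores a full-length token sequence by combining the sparse encoder outputs with the pruned tokens' multi-layer intermediate features, so a dense decoder such as UNETR can run unchanged. The sketch below is an assumption-laden simplification: `assemble_tokens` and the mean-pooling over layers are hypothetical, standing in for whatever learned assembly the paper uses.

```python
import numpy as np

def assemble_tokens(kept_tokens, kept_idx, layer_feats, pruned_idx, seq_len):
    """Hypothetical sketch of Multi-layer Token Assembly (MTA).

    kept_tokens : (num_kept, d)   sparse encoder output tokens
    kept_idx    : (num_kept,)     original positions of the kept tokens
    layer_feats : (L, seq_len, d) intermediate features from L encoder layers
    pruned_idx  : (num_pruned,)   positions that were pruned away
    Returns a (seq_len, d) dense token sequence for the decoder.
    """
    d = kept_tokens.shape[1]
    full = np.zeros((seq_len, d), dtype=kept_tokens.dtype)
    # Scatter the surviving tokens back to their original positions.
    full[kept_idx] = kept_tokens
    # Fill pruned positions by pooling their multi-layer intermediate
    # features (a simple mean here; the paper uses a learned assembly).
    full[pruned_idx] = layer_feats[:, pruned_idx, :].mean(axis=0)
    return full
```

The point of reusing intermediate features, rather than mask tokens as in masked image modeling, is that pruned positions still carry real image evidence from early layers, which the abstract identifies as key to restoring dense features of usable quality.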
DOI: | 10.48550/arxiv.2303.06522 |