Paper Information
Citation: Sundaramoorthy, C., Kelvin, L. Z., Sarin, M., & Gupta, S. (2021). End-to-End Attention-based Image Captioning. arXiv preprint arXiv:2104.14721. https://doi.org/10.48550/arXiv.2104.14721
Publication: arXiv 2021 (preprint)
Note: This is an arXiv preprint and has not undergone formal peer review.
What kind of paper is this?
This is a Methodological Paper ($\Psi_{\text{Method}}$). It proposes a novel architectural approach to molecular image translation by replacing the standard CNN-based encoder with a Vision Transformer (ViT). The authors validate this method through comparative benchmarking against standard CNN+RNN baselines (e.g., ResNet+LSTM) and provide optimizations for inference speed.
What is the motivation?
- Problem: Existing molecular translation methods (extracting chemical structure from images into computer-readable InChI format) rely heavily on rule-based systems or CNN+RNN architectures.
- Gap: These current approaches often underperform when handling noisy images (common in scans of older journals) or images with few distinguishable features.
- Need: There is a significant need in drug discovery to digitize and analyze legacy experimental data locked in image format within scientific publications.
What is the novelty here?
- ViT Encoder: The primary contribution is a completely convolution-free Vision Transformer (ViT) encoder, which lets the model exploit long-range dependencies among image patches from the very first layer via self-attention (see the patch-embedding sketch after this list).
- End-to-End Transformer: The architecture is a pure Transformer (Encoder-Decoder), treating the molecular image similarly to a sequence of tokens (patches).
- Inference Caching: The authors implement a specific caching strategy for the decoder to avoid recomputing embeddings for previously decoded tokens, reducing the time complexity of the decoding step.
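A minimal sketch of this convolution-free input stage, assuming PyTorch; the 384×384 input and 16×16 patches are taken from the details below, while the module name and latent size are illustrative.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and project them linearly.

    Minimal sketch of the ViT input stage: no convolutions, just a reshape
    followed by a trainable linear projection to the latent size D.
    """

    def __init__(self, image_size=384, patch_size=16, in_channels=3, dim=512):
        super().__init__()
        assert image_size % patch_size == 0, "image must divide evenly into patches"
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.proj = nn.Linear(in_channels * patch_size * patch_size, dim)

    def forward(self, images):
        # images: (batch, channels, height, width)
        b, c, h, w = images.shape
        p = self.patch_size
        # (b, c, h/p, p, w/p, p) -> (b, h/p, w/p, c, p, p) -> (b, num_patches, c*p*p)
        patches = images.reshape(b, c, h // p, p, w // p, p)
        patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(b, -1, c * p * p)
        return self.proj(patches)  # (b, num_patches, dim)


# Example: a 384x384 RGB image becomes a sequence of 576 patch embeddings.
tokens = PatchEmbedding()(torch.randn(1, 3, 384, 384))
print(tokens.shape)  # torch.Size([1, 576, 512])
```

Many ViT implementations express the same projection as a strided convolution; the explicit reshape-and-project form above keeps the convolution-free framing visible.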
What experiments were performed?
- Baselines: The model was compared against:
- Standard CNN + RNN
- ResNet (18, 34, 50) + LSTM with attention
- Ablation Studies: Experiments were conducted varying the number of transformer layers (3, 6, 12, 24) and image resolution (224x224 vs 384x384).
- Dataset: A large combined dataset was used, drawing on Bristol Myers Squibb data, a Kaggle SMILES dataset converted to InChI, GDB-13, and synthetically augmented images containing noise and artifacts.
- Metric: Performance was evaluated using the Levenshtein distance, i.e., the minimum number of single-character edits (insertions, deletions, and substitutions) required to transform the predicted string into the ground truth (see the sketch below).
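For concreteness, a minimal pure-Python implementation of this metric (standard dynamic programming, not the paper's code):

```python
def levenshtein(pred: str, target: str) -> int:
    """Minimum number of single-character insertions, deletions, and
    substitutions needed to turn `pred` into `target`."""
    # Classic dynamic programming with two rolling rows.
    prev = list(range(len(target) + 1))
    for i, p in enumerate(pred, start=1):
        curr = [i]
        for j, t in enumerate(target, start=1):
            cost = 0 if p == t else 1
            curr.append(min(prev[j] + 1,          # delete a character of pred
                            curr[j - 1] + 1,      # insert a character of target
                            prev[j - 1] + cost))  # substitute (or match)
        prev = curr
    return prev[-1]


# Two substituted characters -> distance 2.
print(levenshtein("InChI=1S/C2H6O", "InChI=1S/C3H8O"))  # 2
```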
What outcomes/conclusions?
- Superior Performance: The proposed 24-layer ViT model (input size 384) achieved the lowest Levenshtein distance of 6.95, outperforming the ResNet50+LSTM baseline (7.49) and dramatically outperforming the standard CNN+RNN (103.7).
- Importance of Depth: Increasing the number of layers had a strong positive impact, with the 24-layer model becoming competitive with current approaches.
- Robustness: The model performed well on noisy images and on images with few distinguishable features, suggesting high robustness.
- Efficiency: The proposed caching optimization reduces the overall decoding time complexity from $O(MN^2 + N^3)$ to $O(MN + N^2)$.
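A simplified, single-layer sketch of what this caching looks like, assuming PyTorch; dimensions are illustrative and this is not the authors' implementation. Here $M$ is read as the number of encoder (image-patch) tokens and $N$ as the number of tokens decoded so far: cached hidden states let each timestep attend from only the newest token instead of re-processing the whole prefix.

```python
import torch
import torch.nn as nn

dim, heads = 512, 8  # illustrative sizes, not the paper's exact configuration
self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)

memory = torch.randn(1, 576, dim)  # encoder output: M image-patch tokens
cache = torch.zeros(1, 0, dim)     # hidden states of already decoded tokens

def decode_step(new_token_embedding):
    """Run one decoding timestep for the newest token only."""
    global cache
    cache = torch.cat([cache, new_token_embedding], dim=1)  # append position N
    query = cache[:, -1:, :]                 # only the newest token queries
    h, _ = self_attn(query, cache, cache)    # O(N) work instead of O(N^2)
    h, _ = cross_attn(h, memory, memory)     # O(M) work instead of O(M*N)
    return h  # would feed the next layer / output projection

out = decode_step(torch.randn(1, 1, dim))
print(out.shape)  # torch.Size([1, 1, 512])
```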
Reproducibility Details
Data
The model was trained on a combined dataset randomly split into 70% training, 20% validation, and 10% test.
| Dataset | Description | Notes |
|---|---|---|
| Bristol Myers Squibb | ~2.4 million synthetic images with InChI labels. | Provided by Bristol Myers Squibb (BMS), a global biopharmaceutical company. |
| SMILES | Kaggle contest data converted to InChI. | Images generated using RDKit. |
| GDB-13 | Subset of 977 million small organic molecules (up to 13 atoms). | Converted from SMILES using RDKit. |
| Augmented Images | Synthetic images with salt/pepper noise, dropped atoms, and bond modifications. | Used to improve robustness against noise (see the sketch below the table). |
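A minimal NumPy sketch of the salt-and-pepper component of this augmentation; the 2% noise fraction is an assumed value, and the dropped-atom and bond-modification augmentations operate on the molecule rendering itself, so they are not shown.

```python
import numpy as np

def salt_and_pepper(image: np.ndarray, amount: float = 0.02, rng=None) -> np.ndarray:
    """Flip a random fraction of pixels to black (pepper) or white (salt).

    `image` is a grayscale uint8 array; `amount` is the total fraction of
    corrupted pixels, split evenly between salt and pepper.
    """
    rng = np.random.default_rng() if rng is None else rng
    noisy = image.copy()
    mask = rng.random(image.shape)
    noisy[mask < amount / 2] = 0                          # pepper
    noisy[(mask >= amount / 2) & (mask < amount)] = 255   # salt
    return noisy


# Corrupt roughly 2% of the pixels of a random 384x384 grayscale image.
clean = np.random.default_rng(0).integers(0, 256, size=(384, 384), dtype=np.uint8)
print((salt_and_pepper(clean) != clean).mean())  # close to 0.02
```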
Algorithms
- Training Objective: Cross-entropy loss minimization over the tokenized InChI string (see the training-step sketch after this list).
- Inference Decoding: Autoregressive decoding predicting the next character of the InChI string.
- Positional Encoding: Standard sinusoidal encoding built from sine and cosine functions of different frequencies (see the sketch after this list).
- Optimization:
- Caching: Caches the output of each layer during decoding to avoid recomputing embeddings for already decoded tokens.
- JIT: PyTorch JIT compiler used for graph optimization.
- Self-Critical Training: Finetuning performed using self-critical sequence training (SCST).
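A sketch of the training objective on tokenized InChI strings under teacher forcing; the `model` interface, the padding id, and the `ignore_index` handling are assumptions rather than the authors' code.

```python
import torch
import torch.nn.functional as F

PAD_ID = 0  # assumed id of the <PAD> token

def training_step(model, images, token_ids):
    """One cross-entropy step: predict character t from the image and characters < t.

    `model(images, tokens)` is assumed to return logits of shape
    (batch, seq_len, vocab_size) for the 275-token InChI vocabulary.
    """
    inputs, targets = token_ids[:, :-1], token_ids[:, 1:]  # teacher-forcing shift
    logits = model(images, inputs)
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),  # (batch * seq, vocab)
        targets.reshape(-1),                  # (batch * seq,)
        ignore_index=PAD_ID,                  # don't penalise padding positions
    )
    return loss
```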
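A sketch of the standard sinusoidal positional encoding referenced above, following the original Transformer formulation; the sequence length and dimension used in the example are illustrative.

```python
import math
import torch

def sinusoidal_positional_encoding(max_len: int, dim: int) -> torch.Tensor:
    """PE[pos, 2i] = sin(pos / 10000^(2i/dim)), PE[pos, 2i+1] = cos(pos / 10000^(2i/dim))."""
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)   # (max_len, 1)
    div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / dim))                   # (dim/2,)
    pe = torch.zeros(max_len, dim)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe  # added to the token embeddings before the decoder layers


pe = sinusoidal_positional_encoding(max_len=400, dim=512)
print(pe.shape)  # torch.Size([400, 512])
```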
Models
- Encoder (Vision Transformer):
- Input: Flattened 2D patches of the image. Patch size: $16 \times 16$.
- Projection: Trainable linear projection to latent vector size $D$.
- Structure: Alternating layers of Multi-Head Self-Attention (MHSA) and MLP blocks.
- Decoder (Vanilla Transformer):
- Input: Tokenized InChI string + sinusoidal positional embedding.
- Vocabulary: 275 tokens, including `<SOS>`, `<PAD>`, and `<EOS>` (see the decoding sketch after this list).
- Hyperparameters (Best Model):
- Image Size: $384 \times 384$.
- Layers: 24.
- Feature Dimension: 512.
- Attention Heads: 12.
- Optimizer: Adam.
- Learning Rate: $3 \times 10^{-5}$ (decayed by 0.5 in last 2 epochs).
- Batch Size: Varied between 64 and 512.
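A sketch of greedy autoregressive decoding with this decoder, as referenced in the vocabulary item above: starting from `<SOS>`, the next InChI character id is predicted until `<EOS>` or a length cap is reached. The token ids, the length cap, and the `model.encode`/`model.decode` interface are assumptions; beam search and the caching optimization are omitted for clarity.

```python
import torch

SOS_ID, EOS_ID = 1, 2  # assumed token ids

@torch.no_grad()
def greedy_decode(model, image, max_len=400):
    """Greedily decode one InChI string (as token ids) from a molecular image."""
    memory = model.encode(image.unsqueeze(0))      # ViT encoder output: (1, patches, dim)
    tokens = torch.tensor([[SOS_ID]])              # start from <SOS>
    for _ in range(max_len):
        logits = model.decode(memory, tokens)      # (1, seq_len, 275)
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_id], dim=1)
        if next_id.item() == EOS_ID:               # stop at <EOS>
            break
    return tokens.squeeze(0)[1:]                   # drop <SOS>; map ids back to characters afterwards
```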
Evaluation
- Primary Metric: Levenshtein Distance (lower is better).
| Model | Image Size | Layers | Levenshtein Dist. |
|---|---|---|---|
| Standard CNN+RNN | 224 | 3 | 103.7 |
| ResNet50 + LSTM | 224 | 5 | 7.49 |
| ViT Transformers (Best) | 384 | 24 | 6.95 |
Hardware
- System: 70 GB GPU system.
- Framework: PyTorch and PyTorch Lightning.
Citation
@misc{sundaramoorthyEndtoEndAttentionbasedImage2021,
title = {End-to-{{End Attention-based Image Captioning}}},
author = {Sundaramoorthy, Carola and Kelvin, Lin Ziwen and Sarin, Mahak and Gupta, Shubham},
year = 2021,
month = apr,
number = {arXiv:2104.14721},
eprint = {2104.14721},
primaryclass = {cs},
publisher = {arXiv},
doi = {10.48550/arXiv.2104.14721},
archiveprefix = {arXiv}
}