Rough Summary: Spatial Transformer Networks
TL;DR
- A Spatial Transformer learns to spatially transform feature maps, conditioned on the input itself
- It can be inserted at any point in an existing CNN without adding much computation
- It achieves state-of-the-art results on several tasks, including distorted MNIST
Background & Introduction
- CNNs are not truly translation-invariant; max-pooling only partially addresses this
- Max-pooling provides some invariance to translation, but because each pooling window is small (e.g., 2 x 2 pixels), substantial invariance only emerges after many pooling layers, i.e., deep in the network
- A Spatial Transformer (ST) is a dynamic mechanism that learns to predict an appropriate transformation for each input
- By contrast, pooling operations are fixed (not conditioned on the data)
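A toy 1D example makes the limited invariance of a single pooling layer concrete. This is a minimal sketch in plain Python. It assumes non-overlapping pooling with window 2 and stride 2. The helpers `max_pool_1d`, `shift_right` and `stacked_pool` are illustrative and do not come from the paper.

```python
def max_pool_1d(signal, window=2):
    """Non-overlapping 1D max-pooling (stride equals the window size)."""
    return [max(signal[i:i + window]) for i in range(0, len(signal), window)]


def stacked_pool(signal, layers):
    """Apply max_pool_1d repeatedly, mimicking pooling layers stacked in a deep CNN."""
    for _ in range(layers):
        signal = max_pool_1d(signal)
    return signal


def shift_right(signal, k):
    """Shift a signal right by k positions, padding with zeros on the left."""
    return [0] * k + signal[:len(signal) - k]


x = [1, 0, 0, 0, 0, 0, 0, 0]  # a single active "pixel"

for k in range(5):
    shifted = shift_right(x, k)
    # one pooling layer vs. two stacked pooling layers
    print(k, max_pool_1d(shifted), stacked_pool(shifted, 2))
```

With one pooling layer, the output stays unchanged only while the shift keeps the active pixel inside the same pooling window. A 1-pixel shift from position 0 is absorbed, but a 2-pixel shift is not. With two stacked layers, shifts of up to 3 pixels leave the output unchanged here. This is why useful invariance only appears deep in the network.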
Model: Spatial Transformers
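For a given network architecture (e.g., a CNN), you can choose where to insert a Spatial Transformer. The feature map entering it is called U, and the transformed feature map it outputs is called V, as shown in the figure.
The localisation net takes the feature map U and outputs a few transformation parameters θ. These parameters can simply be the elements of a transformation matrix (e.g., the 2 x 3 matrix A_θ of an affine transformation).
The grid generator uses θ to compute, for every pixel of the output V, the location in U to sample from. Note that the mapping runs from V back to U. The sampler then reads U at those locations, e.g., with bilinear interpolation, to produce V.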
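For the affine case, the paper writes the two steps roughly as follows. Each output pixel $i$ of V has target coordinates $(x^t_i, y^t_i)$. The 2 x 3 matrix $A_\theta$, built from the localisation net's output θ, maps it to a source point $(x^s_i, y^s_i)$ in U. The bilinear sampler then mixes the pixels of U nearest to that point. Here $U^c_{nm}$ is channel $c$ of U at row $n$, column $m$. Coordinates are normalized to $[-1, 1]$ for the affine map, and the source coordinates are converted to pixel units before the bilinear step.

```latex
\begin{pmatrix} x^s_i \\ y^s_i \end{pmatrix}
  = A_\theta \begin{pmatrix} x^t_i \\ y^t_i \\ 1 \end{pmatrix}
  = \begin{bmatrix} \theta_{11} & \theta_{12} & \theta_{13} \\
                    \theta_{21} & \theta_{22} & \theta_{23} \end{bmatrix}
    \begin{pmatrix} x^t_i \\ y^t_i \\ 1 \end{pmatrix}

V^c_i = \sum_{n=1}^{H} \sum_{m=1}^{W} U^c_{nm}\,
        \max(0,\, 1 - |x^s_i - m|)\,\max(0,\, 1 - |y^s_i - n|)
```

Because the grid is defined on V and points back into U, every output pixel receives exactly one well-defined value.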
Brief description of each component of a Spatial Transformer
- Localisation Network
  - input: feature map U
  - output: transformation parameters θ to apply to the feature map (e.g., 6-dimensional for an affine transformation). This network can take any form, such as a fully connected network or a CNN.
- Grid Generator
  - input: parameters θ and the size of the output feature map
  - output: sampling grid, i.e., the point in U to sample for each pixel of V
- Sampler
  - input: feature map U and the sampling grid
  - output: feature map V, sampled from U at the grid points
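The three components can be put together in a small, dependency-free sketch. It makes these assumptions:

- The feature map is single-channel and stored as nested lists.
- The affine θ is given as 6 numbers.
- Normalization follows an "align-corners" convention, where -1 and 1 map to the centres of the first and last pixels.

The one-layer localisation net and all function names are hypothetical stand-ins; the paper allows any network as the localisation net.

```python
def localisation_net(u, weights, bias):
    """Hypothetical localisation net: one linear layer from the flattened map to 6 params."""
    flat = [v for row in u for v in row]
    return [b + sum(w * v for w, v in zip(w_row, flat)) for w_row, b in zip(weights, bias)]


def affine_grid(theta, out_h, out_w):
    """Grid generator: map each output pixel (normalized to [-1, 1]) to a source point in U.

    theta = [t11, t12, t13, t21, t22, t23] holds the 2 x 3 affine matrix row by row.
    Returns out_h rows of (x_s, y_s) pairs in normalized source coordinates.
    """
    t11, t12, t13, t21, t22, t23 = theta
    grid = []
    for r in range(out_h):
        y_t = -1.0 + 2.0 * r / (out_h - 1) if out_h > 1 else 0.0
        row = []
        for c in range(out_w):
            x_t = -1.0 + 2.0 * c / (out_w - 1) if out_w > 1 else 0.0
            row.append((t11 * x_t + t12 * y_t + t13, t21 * x_t + t22 * y_t + t23))
        grid.append(row)
    return grid


def bilinear_sample(u, grid):
    """Sampler: read u at each grid point with the kernel max(0, 1 - |x - m|) per axis.

    Points outside u only receive contributions from in-bounds pixels (implicit zero padding).
    """
    h, w = len(u), len(u[0])
    out = []
    for row in grid:
        out_row = []
        for x_s, y_s in row:
            # convert normalized [-1, 1] coordinates to 0-based pixel coordinates
            x = (x_s + 1.0) * (w - 1) / 2.0
            y = (y_s + 1.0) * (h - 1) / 2.0
            value = 0.0
            for n in range(h):
                wy = max(0.0, 1.0 - abs(y - n))
                if wy == 0.0:
                    continue  # row n does not contribute
                for m in range(w):
                    wx = max(0.0, 1.0 - abs(x - m))
                    value += u[n][m] * wx * wy
            out_row.append(value)
        out.append(out_row)
    return out


def spatial_transformer(u, weights, bias, out_h, out_w):
    """Localisation net -> grid generator -> sampler."""
    theta = localisation_net(u, weights, bias)
    return bilinear_sample(u, affine_grid(theta, out_h, out_w))


u = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
zero_w = [[0.0] * 9 for _ in range(6)]  # untrained weights: theta equals the bias for any input
identity = spatial_transformer(u, zero_w, [1, 0, 0, 0, 1, 0], 3, 3)
flipped = spatial_transformer(u, zero_w, [-1, 0, 0, 0, 1, 0], 3, 3)   # x_s = -x_t
zoomed = spatial_transformer(u, zero_w, [0.5, 0, 0, 0, 0.5, 0], 3, 3)  # read the central region
print(identity, flipped, zoomed, sep="\n")
```

Here the localisation weights are zero, so θ equals the bias for every input. In a trained ST the weights make θ depend on U, which is what makes the transformation data-dependent. For a multi-channel map, the same grid is applied to every channel.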
Features to note
- A Spatial Transformer is a self-contained module that can be dropped in at any point of a CNN architecture.
- It is computationally cheap and adds little overhead.
- It is differentiable, so it can be trained end-to-end with standard backpropagation, without extra supervision.
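End-to-end training works because the sampler is (sub-)differentiable with respect to U and also with respect to the sampling coordinates. That lets the loss gradient flow through the grid generator into θ and the localisation net. The bilinear kernel factorizes per axis, so a 1D version is enough to see this. The sketch below uses hypothetical helper names and checks the analytic gradient against a finite difference.

```python
def sample_1d(u, x):
    """Linear interpolation of the 1D signal u at pixel coordinate x, written with the
    same kernel max(0, 1 - |x - m|) that the bilinear sampler applies along each axis."""
    return sum(u_m * max(0.0, 1.0 - abs(x - m)) for m, u_m in enumerate(u))


def grad_sample_1d(u, x):
    """Analytic (sub-)gradient of sample_1d with respect to the sampling coordinate x.

    For |x - m| < 1, d/dx max(0, 1 - |x - m|) is -1 if x > m and +1 if x < m;
    at x == m exactly this picks +1 as a sub-gradient.
    """
    g = 0.0
    for m, u_m in enumerate(u):
        if abs(x - m) < 1.0:
            g += u_m * (-1.0 if x > m else 1.0)
    return g


u_row = [0.0, 2.0, 5.0, 3.0]
x = 1.25  # sampling point between pixels 1 and 2
h = 1e-4
numeric = (sample_1d(u_row, x + h) - sample_1d(u_row, x - h)) / (2 * h)
analytic = grad_sample_1d(u_row, x)
print(sample_1d(u_row, x), analytic, numeric)
```

The gradient with respect to each input value `u[m]` is simply its kernel weight. Each source coordinate is linear in θ, so chaining this result through the affine grid generator gives the gradient with respect to θ.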
Experiments
The experiments yield several interesting observations.
Distorted MNIST
Adding a Spatial Transformer in front of an FCN or a CNN decreases the error rate.
![Table 1](/media/posts/spatial-transformer/spatial_transformer_tb1.png =500x)
* (a) Input image, (b) the transformation predicted by the ST, (c) the output of the ST
One can see that the ST learns reasonable transformations.
NOTE: Here the Spatial Transformer is applied to the input image, before the classification network (FCN or CNN).
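Which θ must the ST predict to produce upright digits in (c)? The answer is subtle because the grid generator maps output coordinates back into the input. To undo a rotation by angle a, the ST therefore has to predict the rotation by a itself, not its inverse. The sketch below checks this on a toy continuous "digit", a smooth vertical stroke. It assumes a pure rotation with no translation or scale, in a y-up coordinate frame. The helpers `rotation`, `mat_vec`, `digit`, `rotated_digit` and `st_output` are hypothetical.

```python
import math


def rotation(angle_deg):
    """2 x 2 rotation matrix (counter-clockwise in a y-up frame) as nested lists."""
    a = math.radians(angle_deg)
    return [[math.cos(a), -math.sin(a)], [math.sin(a), math.cos(a)]]


def mat_vec(mat, p):
    """Multiply a 2 x 2 matrix by the point p = (x, y)."""
    return (mat[0][0] * p[0] + mat[0][1] * p[1], mat[1][0] * p[0] + mat[1][1] * p[1])


def digit(p):
    """Toy upright 'digit': a smooth vertical stroke in normalized coordinates."""
    x, y = p
    return math.exp(-(x / 0.2) ** 2 - (y / 0.8) ** 2)


def rotated_digit(q, angle_deg):
    """Distorted input: the stroke rotated by angle_deg, i.e. D(q) = I(R(-a) q)."""
    return digit(mat_vec(rotation(-angle_deg), q))


def st_output(p, A, angle_deg):
    """ST output at target point p: read the distorted input at the source point A p."""
    return rotated_digit(mat_vec(A, p), angle_deg)


angle = 60.0  # an angle inside the -90..90 degree range of the rotated data
grid = [(-0.9 + 0.2 * i, -0.9 + 0.2 * j) for i in range(10) for j in range(10)]

undo = rotation(angle)    # the same rotation as the distortion
wrong = rotation(-angle)  # the "intuitive" inverse rotation

err_undo = max(abs(st_output(p, undo, angle) - digit(p)) for p in grid)
err_wrong = max(abs(st_output(p, wrong, angle) - digit(p)) for p in grid)
print(err_undo, err_wrong)
```

`err_undo` is at floating-point noise level, while `err_wrong` is large. The wrong choice rotates the stroke a further 60° instead of restoring it. The same reasoning holds in image coordinates with y pointing down; only the visual sense of "counter-clockwise" flips.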
Questions
- Why are the transformed digits ((c) in Table 1) oriented the way humans read them? (In principle, the ST could consistently rotate every digit to some other angle; why doesn't that happen?)
  - Intuitively, this comes from how the data is created: rotations are drawn randomly between -90° and 90° rather than over the full 360°, so the upright orientation sits at the centre of the training distribution.
Classifying a bird
The authors build a model that contains 2 spatial transformers in a CNN. Interestingly, one of them learns to crop the bird's head and the other crops its body. They observe similar behaviour with a 4-transformer architecture.
![Table 3](/media/posts/spatial-transformer/spatial_transformer_tb3.png =500x)
* First row: 2 ST-CNN, second row: 4 ST-CNN
In the first row, the red box (the region cropped by the first ST) covers the bird's head, and the green box covers its body.
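Each coloured box is the region of the input that one ST reads. This sketch assumes the ST is restricted to a crop: an isotropic scale s plus a translation (t_x, t_y), i.e. the special case A_θ = [[s, 0, t_x], [0, s, t_y]] of the affine matrix. Under that assumption, the box follows directly from θ. The function `crop_box`, the 201 x 201 image size and the parameter values are made up for illustration; they are not taken from the paper.

```python
def crop_box(s, tx, ty, width, height):
    """Pixel box (left, top, right, bottom) read by the crop transform
    A = [[s, 0, tx], [0, s, ty]] when the output grid spans [-1, 1] on both axes.

    Uses an align-corners normalization: pixel = (v + 1) * (size - 1) / 2,
    so negative ty moves the box towards the top rows of the image.
    """
    def to_px(v, size):
        return (v + 1.0) * (size - 1) / 2.0

    return (to_px(tx - s, width), to_px(ty - s, height),
            to_px(tx + s, width), to_px(ty + s, height))


# Hypothetical parameters for two parallel STs on a 201 x 201 image
head = crop_box(0.25, -0.5, -0.5, 201, 201)  # small crop in the upper-left area
body = crop_box(0.5, 0.0, 0.25, 201, 201)    # larger crop slightly below the centre
print(head, body)
```

In this form, each ST needs to predict only 3 numbers instead of 6, and its output can be drawn directly as a box on the input.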
Two-digit MNIST addition
Each input image contains 2 digits, and the model must output their sum. The authors combine 2 Spatial Transformers in one network, as shown in the figure.
Using 2 Spatial Transformers drastically improves accuracy over plain FCN and CNN baselines. (They also tried a single Spatial Transformer, and the result was worse than the plain CNN. This makes sense, since one transformation presumably cannot align two separately placed digits at the same time.)
![Table 4](/media/posts/spatial-transformer/spatial_transformer_tb4.png =500x)
As the figure shows, ST1 learns to crop one digit and ST2 crops the other.