Rough Summary: Spatial Transformer Networks

TL;DR

  • Spatial Transformer learns to transform the feature maps
  • It can be inserted at any point in an existing CNN, and it doesn't increase the computation much
  • It achieves SOTA on several tasks, including distorted MNIST

Background & Introduction

  • A CNN is not really translation-invariant; max-pooling does not fully take care of that
  • Max-pooling mitigates the lack of translation invariance, but since the pooling window (e.g., 2 x 2 pixels) is small, the invariance is only realised once the network is deep
  • Spatial Transformer (ST) is a dynamic mechanism that learns to perform the proper transformation for the given data
    • In contrast, pooling operations are fixed (not conditioned on the data)

Model: Spatial Transformers

architecture

For a given network architecture (e.g., a CNN), you can choose where to insert the Spatial Transformer; it takes the feature map U produced up to that point and outputs a transformed feature map V, as shown in the figure.

The localisation net takes the feature map and spits out a few transformation parameters $\theta$. These parameters can simply be the elements of a transformation matrix $A_{\theta}$.
The grid generator defines the mapping from feature map U to V, i.e., where in U each point of V should be sampled from. The sampler then produces V by sampling U at those grid points.
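
For the affine case, the grid generator applies $A_{\theta}$ pointwise: each target coordinate $(x_i^t, y_i^t)$ in the output map V is mapped back to the source coordinate $(x_i^s, y_i^s)$ in U that the sampler should read from (this is the pointwise transformation given in the paper):

$$
\begin{pmatrix} x_i^s \\ y_i^s \end{pmatrix}
= \mathcal{T}_{\theta}(G_i)
= A_{\theta}
\begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix}
= \begin{bmatrix} \theta_{11} & \theta_{12} & \theta_{13} \\ \theta_{21} & \theta_{22} & \theta_{23} \end{bmatrix}
\begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix}
$$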

Brief description of each component in the Spatial Transformer (a code sketch combining them follows the list)

  1. Localisation Network
  • input: feature map
  • output: transformation parameters to be applied to the feature map (e.g., 6-dimensional for an affine transformation). This network can take any form, such as a CNN.
  2. Grid Generator
  • input: parameters and feature map
  • output: sampling grid
  3. Sampler
  • input: feature map and sampling grid
  • output: feature map sampled at the grid points
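
A minimal sketch of how the three components fit together, assuming PyTorch: F.affine_grid plays the role of the grid generator and F.grid_sample the (bilinear) sampler. The small localisation network below is illustrative only, not the paper's exact architecture.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SpatialTransformer(nn.Module):
    """Self-contained ST module: localisation net -> grid generator -> sampler."""

    def __init__(self, in_channels: int):
        super().__init__()
        # 1. Localisation network: regresses the 6 affine parameters theta.
        self.loc = nn.Sequential(
            nn.Conv2d(in_channels, 8, kernel_size=7), nn.MaxPool2d(2), nn.ReLU(),
            nn.Conv2d(8, 10, kernel_size=5), nn.MaxPool2d(2), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(10, 6),
        )
        # Start from the identity transform so the module is initially a no-op.
        self.loc[-1].weight.data.zero_()
        self.loc[-1].bias.data.copy_(torch.tensor([1., 0., 0., 0., 1., 0.]))

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        theta = self.loc(u).view(-1, 2, 3)          # (N, 2, 3) affine matrices A_theta
        # 2. Grid generator: for every target location, compute the source location.
        grid = F.affine_grid(theta, u.size(), align_corners=False)
        # 3. Sampler: bilinearly sample U at the grid points to produce V.
        return F.grid_sample(u, grid, align_corners=False)
```

Initialising the final regression layer to the identity transform (done above) is a natural choice: the module passes the feature map through unchanged at the start of training and only deviates from that when it lowers the task loss.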

Features to note

  • The Spatial Transformer is a self-contained module and can be dropped in at any point of a CNN architecture.
  • It is computationally cheap; there is no heavy overhead.
  • It can be trained in a fully end-to-end fashion, since the sampling is differentiable (see the quick gradient check below).
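
A quick sanity check, using the hypothetical SpatialTransformer sketch above with made-up input shapes, that gradients indeed reach both the input feature map and the localisation network:

```python
st = SpatialTransformer(in_channels=1)
u = torch.randn(4, 1, 28, 28, requires_grad=True)   # a batch of fake 28x28 feature maps
v = st(u)                                            # transformed maps, same shape as u
v.sum().backward()                                   # any scalar loss would do here
print(u.grad.shape)                                  # gradient w.r.t. the input map
print(st.loc[-1].weight.grad.shape)                  # gradient w.r.t. the theta regressor
```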

Experiments

There are several interesting observations across the experiments.

Distorted MNIST

Equipping the FCN and CNN with a Spatial Transformer decreases the error rate. ![Table 1](/media/posts/spatial-transformer/spatial_transformer_tb1.png =500x) * (a) Input image, (b) The transformations predicted by ST, (c) The output of ST

One can see that reasonable transformations are learned.

NOTE: the Spatial Transformer is applied to the input image before the classification network (FCN or CNN).
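
A hypothetical sketch of that arrangement, reusing the SpatialTransformer module from above; the classifier here is a toy stand-in, not the paper's FCN/CNN:

```python
class STClassifier(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.st = SpatialTransformer(in_channels=1)   # transform the raw input first
        self.cls = nn.Sequential(                     # then classify the transformed image
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Flatten(), nn.LazyLinear(num_classes),
        )

    def forward(self, x):
        return self.cls(self.st(x))
```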

Questions

  • Why are the transformed digits ((c) in Table 1) upright, i.e., consistent with the way humans read those characters? (The network could just as well have settled on some other consistent angle, so why is that not the case?)
    • Intuitively, because of the way the data is created (rotations are drawn randomly between -90 and +90 degrees, not the full 360 degrees), the upright pose is the canonical orientation consistent with all samples

Classifying a bird

They prepare a model that contains 2 spatial transformers in a CNN. Interestingly, the results show that one of them learns to crop the head and the other crops the body. They observe similar behaviour for the 4-transformer architecture as well. ![Table 3](/media/posts/spatial-transformer/spatial_transformer_tb3.png =500x) * First row: 2 ST-CNN, Second row: 4 ST-CNN

In the first row, one can see that the red box (the crop produced by the first ST) captures the bird's head and the green box captures its body.

Two-digit MNIST addition

The input image contains 2 digits, and the model needs to output the sum of those numbers. They fuse 2 Spatial Transformers into one network, as shown in the figure.

Having 2 Spatial Transformers drastically improves the accuracy compared with the plain FCN and CNN. (They also try fusing a single Spatial Transformer, and the result is worse than the plain CNN, which makes sense, since one transformer cannot attend to both digits at once.)

![Table 4](/media/posts/spatial-transformer/spatial_transformer_tb4.png =500x) As you can see from the figure, ST1 learns to crop one number and ST2 does the other.
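
As a rough sketch of this arrangement, the two transformers below each see the whole input image and their outputs are fused before classification. The channel-wise concatenation is an assumption made for illustration; the summary does not specify exactly how the paper combines the two streams.

```python
class TwoDigitAdder(nn.Module):
    """Two parallel STs over the same input; their outputs are fused for classification."""

    def __init__(self, num_classes: int = 19):        # possible sums of two digits: 0..18
        super().__init__()
        self.st1 = SpatialTransformer(in_channels=1)  # empirically latches onto one digit
        self.st2 = SpatialTransformer(in_channels=1)  # ...and the other
        self.cls = nn.Sequential(
            nn.Conv2d(2, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Flatten(), nn.LazyLinear(num_classes),
        )

    def forward(self, x):
        # Concatenate the two transformed views channel-wise (an assumption, not
        # necessarily the fusion used in the paper) and classify the result.
        fused = torch.cat([self.st1(x), self.st2(x)], dim=1)
        return self.cls(fused)
```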

Links