A PyTorch Implementation of "Recurrent Models of Visual Attention"
This is a PyTorch implementation of *Recurrent Models of Visual Attention* by Volodymyr Mnih, Nicolas Heess, Alex Graves and Koray Kavukcuoglu.
The Recurrent Attention Model (RAM) is a neural network that processes inputs sequentially, attending to different locations within the image one at a time, and incrementally combining information from these fixations to build up a dynamic internal representation of the image.
In this paper, the attention problem is modeled as the sequential decision process of a goal-directed agent interacting with a visual environment. The agent is built around a recurrent neural network: at each time step, it processes the sensor data, integrates information over time, and chooses how to act and how to deploy its sensor at the next time step.
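The perceive, integrate, act cycle described above can be sketched as a plain Python loop. This is a minimal illustration, not this repository's code: `run_agent` and its callback parameters (`sense`, `integrate`, `next_location`, `act`) are hypothetical names standing in for the sensor, the recurrent state update, the choice of the next fixation and the final action. Starting the first fixation at the image centre is also an assumption.

```python
from typing import Callable, List, Sequence, Tuple

Image = Sequence[Sequence[float]]
Location = Tuple[float, float]
Vector = List[float]


def run_agent(
    image: Image,
    num_glimpses: int,
    sense: Callable[[Image, Location], Vector],
    integrate: Callable[[Vector, Vector, Location], Vector],
    next_location: Callable[[Vector], Location],
    act: Callable[[Vector], int],
    state_size: int,
) -> Tuple[int, List[Location]]:
    """Roll out one episode of a RAM-style agent (illustrative sketch)."""
    h: Vector = [0.0] * state_size  # internal state, starts empty
    loc: Location = (0.0, 0.0)      # assumption: first fixation at the image centre
    fixations: List[Location] = []
    for _ in range(num_glimpses):
        fixations.append(loc)
        observation = sense(image, loc)     # process the sensor data
        h = integrate(h, observation, loc)  # integrate information over time
        loc = next_location(h)              # decide where to deploy the sensor next
    return act(h), fixations                # act once the glimpse budget is used up
```

Keeping the components as callbacks makes the control flow explicit; in a PyTorch model each callback would presumably be a learned module and `h` a tensor.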
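- **glimpse sensor**: a retina-like sensor that extracts a foveated glimpse `phi` around location `l` from an image `x`. It encodes the region around `l` at a high resolution but uses a progressively lower resolution for pixels further from `l`, resulting in a compressed representation of the original image `x`.
- **glimpse network**: combines the "what" (`phi`) and the "where" (`l`) into a glimpse feature vector `g_t`.
- **core network**: a recurrent network whose internal state `h_t` summarizes the glimpses seen so far and gets updated at every time step `t`.
- **location network**: uses the internal state `h_t` of the core network to produce the location coordinates `l_t` for the next time step.
- **action network**: uses the internal state `h_t` of the core network to produce the final output classification.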
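One way to realise a sensor that encodes the region around `l` at high resolution and the surroundings at progressively lower resolution is to crop several concentric square patches around `l`, each `s` times wider than the previous one, and downsample every patch to the same `g x g` size. The sketch below uses nested Python lists to stay dependency-free. The names `extract_patch`, `avg_pool` and `glimpse`, the integer scale factor `s`, zero padding outside the image, average pooling as the downsampling step, and the mapping of `l = (row, col)` from `[-1, 1]` to pixel indices are all assumptions, not details taken from this implementation.

```python
from typing import List, Sequence, Tuple

Image = Sequence[Sequence[float]]
Patch = List[List[float]]


def extract_patch(image: Image, center: Tuple[int, int], size: int) -> Patch:
    """Crop a size x size square around `center` (row, col), zero-padding outside the image."""
    rows, cols = len(image), len(image[0])
    r0 = center[0] - size // 2
    c0 = center[1] - size // 2
    return [
        [float(image[r][c]) if 0 <= r < rows and 0 <= c < cols else 0.0
         for c in range(c0, c0 + size)]
        for r in range(r0, r0 + size)
    ]


def avg_pool(patch: Patch, factor: int) -> Patch:
    """Downsample by averaging non-overlapping factor x factor blocks."""
    n = len(patch) // factor
    return [
        [
            sum(patch[i * factor + a][j * factor + b]
                for a in range(factor) for b in range(factor)) / factor ** 2
            for j in range(n)
        ]
        for i in range(n)
    ]


def glimpse(image: Image, loc: Tuple[float, float], g: int, k: int, s: int) -> List[Patch]:
    """Return k patches of g x g values centred on `loc`.

    Patch i covers g * s**i pixels per side and is average-pooled back to
    g x g, so resolution drops as the covered area grows.
    """
    rows, cols = len(image), len(image[0])
    # Assumption: map normalised coordinates in [-1, 1] to pixel indices.
    center = (int(0.5 * (loc[0] + 1.0) * rows), int(0.5 * (loc[1] + 1.0) * cols))
    patches = []
    for i in range(k):
        factor = s ** i
        patches.append(avg_pool(extract_patch(image, center, g * factor), factor))
    return patches
```

With `k=1` the sensor reduces to a single full-resolution crop; larger `k` and `s` widen the field of view while keeping every patch the same size, at the cost of detail in the periphery.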
I decided to tackle the `28x28` MNIST task with the RAM model containing 6 glimpses of size `8x8`, with a scale factor of …
| Model | Validation Error (%) | Test Error (%) |
|-------|----------------------|----------------|
| 6 glimpses, 8x8 | 1.1 | 1.21 |
I haven't run a random search over the policy standard deviation to tune it, so I expect the test error can be brought below `1%`. I'll update the table above with results for the `60x60` Cluttered Translated MNIST and the new Fashion MNIST dataset when I get the time.
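The policy standard deviation is the spread of the Gaussian from which the next fixation is drawn: in the paper, the location network outputs the mean of a Gaussian with fixed variance, and the sampled locations are trained with REINFORCE. A minimal sketch of that sampling step, assuming an isotropic 2-D Gaussian and using the hypothetical helper names `sample_location` and `gaussian_log_prob`:

```python
import math
import random
from typing import Sequence, Tuple


def gaussian_log_prob(x: Sequence[float], mu: Sequence[float], std: float) -> float:
    """Log-density of x under an isotropic Gaussian N(mu, std**2 * I)."""
    return sum(
        -((xi - mi) ** 2) / (2.0 * std ** 2) - math.log(std) - 0.5 * math.log(2.0 * math.pi)
        for xi, mi in zip(x, mu)
    )


def sample_location(
    mu: Sequence[float], std: float, rng: random.Random
) -> Tuple[Tuple[float, ...], float]:
    """Draw the next fixation around the location network's mean `mu`.

    Returns the sampled location and its log-probability; REINFORCE scales
    the gradient of this log-probability by (reward - baseline).
    """
    loc = tuple(rng.gauss(m, std) for m in mu)
    return loc, gaussian_log_prob(loc, mu, std)
```

A larger `std` makes the agent explore more of the image but typically makes the REINFORCE gradient estimate noisier, which is why this value is worth tuning.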
Finally, here's an animation showing the glimpses extracted by the network on a random batch at epoch 23.
With the Adam optimizer, the accuracy reported in the paper can be reached in ~160 epochs.
The easiest way to start training your RAM variant is to edit the parameters in `config.py` and run `python main.py`.
To resume training, run:
```bash
python main.py --resume=True
```
Finally, to test a checkpoint of your model that has achieved the best validation accuracy, run the following command:
```bash
python main.py --is_train=False
```
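The flags above take an explicit value (`--resume=True`, `--is_train=False`) rather than acting as bare switches. With `argparse` this is usually done with a small string-to-boolean converter, because `type=bool` would turn the string `"False"` into `True`. The parser below is a hypothetical sketch of how such options could be declared; it is not copied from this repository's `config.py`, and the defaults shown are assumptions.

```python
import argparse


def str2bool(value: str) -> bool:
    """Interpret 'True'/'False'-style strings; argparse's type=bool would treat 'False' as True."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical subset of RAM options; names mirror the commands above."""
    parser = argparse.ArgumentParser(description="RAM options (illustrative subset)")
    parser.add_argument("--is_train", type=str2bool, default=True,
                        help="train if True, otherwise evaluate the best checkpoint")
    parser.add_argument("--resume", type=str2bool, default=False,
                        help="resume training from a saved checkpoint")
    return parser
```

For example, `build_parser().parse_args(["--resume=True"])` yields `resume=True` and leaves `is_train` at its default.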