
Recurrent Visual Attention

This is a PyTorch implementation of Recurrent Models of Visual Attention by Volodymyr Mnih, Nicolas Heess, Alex Graves and Koray Kavukcuoglu.


The Recurrent Attention Model (RAM) is a neural network that processes inputs sequentially, attending to different locations within the image one at a time, and incrementally combining information from these fixations to build up a dynamic internal representation of the image.

Model Description

In this paper, the attention problem is modeled as the sequential decision process of a goal-directed agent interacting with a visual environment. The agent is built around a recurrent neural network: at each time step, it processes the sensor data, integrates information over time, and chooses how to act and how to deploy its sensor at the next time step.

[Figure: the RAM architecture, whose components are described below]

  • glimpse sensor: a retina that extracts a foveated glimpse phi around location l from an image x. It encodes the region around l at high resolution but uses progressively lower resolution for pixels farther from l, resulting in a compressed representation of the original image x.
  • glimpse network: a network that combines the "what" (phi) and the "where" (l) into a glimpse feature vector g_t.
  • core network: an RNN that maintains an internal state integrating information extracted from the history of past observations. It encodes the agent's knowledge of the environment through a state vector h_t that is updated at every time step t.
  • location network: uses the internal state h_t of the core network to produce the location coordinates l_t for the next time step.
  • action network: after a fixed number of time steps, uses the internal state h_t of the core network to produce the final output classification y.

Results

I decided to tackle the 28x28 MNIST task with a RAM model using 6 glimpses, each of size 8x8, with a scale factor of 1.

| Model | Validation Error (%) | Test Error (%) |
|-------|----------------------|----------------|
| 6 8x8 | 1.1                  | 1.21           |

I haven't done a random search over the policy standard deviation to tune it, so I expect the test error can be pushed below 1%. I'll update the table above with results for 60x60 Translated MNIST, 60x60 Cluttered Translated MNIST, and the new Fashion MNIST dataset when I get the time.
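
For context on what the policy standard deviation controls: the location network's output parameterizes a Gaussian over the next fixation, sampled stochastically and trained with REINFORCE, while the classification output is trained with plain cross-entropy. Below is a hedged sketch of that hybrid objective; the function names and exact baseline handling are illustrative, not taken from the repo.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal


def sample_location(mu, std):
    """Gaussian location policy: l_t ~ N(mu, std^2), as in the paper."""
    dist = Normal(mu, std)
    l_t = dist.sample()                     # stochastic fixation
    log_pi = dist.log_prob(l_t).sum(dim=1)  # log-prob per sample
    return torch.clamp(l_t, -1.0, 1.0), log_pi


def hybrid_loss(logits, labels, log_pis, baselines):
    """Cross-entropy for the action network plus REINFORCE for the
    location network, with a learned baseline to cut gradient variance.

    logits:    (B, num_classes) final-step class scores
    labels:    (B,) ground-truth labels
    log_pis:   (T, B) log pi(l_t) stacked over the T glimpses
    baselines: (B,) learned reward estimate
    """
    loss_action = F.cross_entropy(logits, labels)
    # Reward is 1 for a correct final classification, 0 otherwise.
    reward = (logits.argmax(dim=1) == labels).float()
    loss_baseline = F.mse_loss(baselines, reward)
    # REINFORCE: reinforce fixation sequences that earned reward.
    advantage = (reward - baselines).detach()
    loss_reinforce = -(log_pis * advantage.unsqueeze(0)).sum(dim=0).mean()
    return loss_action + loss_baseline + loss_reinforce
```

Because the reward only arrives at the final step, the baseline is what keeps the variance of the REINFORCE term manageable; the standard deviation passed to sample_location is the knob referred to above.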

Finally, here's an animation showing the glimpses extracted by the network on a random batch at epoch 23.


With the Adam optimizer, paper accuracy can be reached in ~160 epochs.

Usage

The easiest way to start training your RAM variant is to edit the parameters in config.py and run the following command:

```
python main.py
```

To resume training, run:

```
python main.py --resume=True
```

Finally, to test the checkpoint of your model that achieved the best validation accuracy, run the following command:

```
python main.py --is_train=False
```
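
If config.py exposes its parameters as command-line flags (common for this kind of script, but the flag names below are assumptions rather than confirmed repo options), a variant could also be configured without editing the file:

```
# Hypothetical flag names; check config.py for the real ones.
python main.py --num_glimpses=6 --patch_size=8 --glimpse_scale=1
```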

References

  • Volodymyr Mnih, Nicolas Heess, Alex Graves, Koray Kavukcuoglu. "Recurrent Models of Visual Attention." NIPS 2014. arXiv:1406.6247.
