Implementation of Slot Attention from GoogleAI
Implementation of Slot Attention from the paper 'Object-Centric Learning with Slot Attention' in PyTorch. Here is a video that describes what this network can do.
Update: The official repository has been released here
```bash
$ pip install slot_attention
```
```python
import torch
from slot_attention import SlotAttention

slot_attn = SlotAttention(
    num_slots = 5,
    dim = 512,
    iters = 3   # iterations of attention, defaults to 3
)

inputs = torch.randn(2, 1024, 512)
slot_attn(inputs) # (2, 5, 512)
```
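For intuition, below is a minimal sketch of the iterative attention update described in the paper: initial slots are sampled from a learned Gaussian shared across slots, the slots compete for input features via a softmax over the slot axis, and each slot is then refined with a GRU followed by a small residual MLP. The class name, epsilon value, and layer sizes here are illustrative assumptions, not the exact code of this package.

```python
import torch
from torch import nn

class SlotAttentionSketch(nn.Module):
    def __init__(self, num_slots, dim, iters = 3, eps = 1e-8):
        super().__init__()
        self.num_slots = num_slots
        self.iters = iters
        self.eps = eps
        self.scale = dim ** -0.5

        # initial slots are drawn from a learned Gaussian, shared across slots
        self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))
        self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, dim))

        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)

        self.gru = nn.GRUCell(dim, dim)
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(inplace = True), nn.Linear(dim, dim))

        self.norm_inputs = nn.LayerNorm(dim)
        self.norm_slots = nn.LayerNorm(dim)
        self.norm_mlp = nn.LayerNorm(dim)

    def forward(self, inputs, num_slots = None):
        b, n, d = inputs.shape
        n_s = num_slots if num_slots is not None else self.num_slots

        # sample initial slots
        mu = self.slots_mu.expand(b, n_s, -1)
        sigma = self.slots_logsigma.exp().expand(b, n_s, -1)
        slots = mu + sigma * torch.randn_like(mu)

        inputs = self.norm_inputs(inputs)
        k, v = self.to_k(inputs), self.to_v(inputs)

        for _ in range(self.iters):
            slots_prev = slots
            q = self.to_q(self.norm_slots(slots))

            # attention logits: (batch, slots, inputs); softmax over the slot axis
            # so that slots compete with each other for each input feature
            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
            attn = dots.softmax(dim = 1) + self.eps
            attn = attn / attn.sum(dim = -1, keepdim = True)  # weighted mean over inputs

            updates = torch.einsum('bjd,bij->bid', v, attn)

            # per-slot GRU update, then a residual MLP
            slots = self.gru(updates.reshape(-1, d), slots_prev.reshape(-1, d)).reshape(b, n_s, d)
            slots = slots + self.mlp(self.norm_mlp(slots))

        return slots
```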
After training, the network is reported to generalize to a slightly different number of slots (clusters), since the initial slots are sampled independently from a single learned distribution shared across slots. You can override the number of slots used by passing the `num_slots` keyword to `forward`.
```python
slot_attn(inputs, num_slots = 8) # (2, 8, 512)
```
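In the object-centric setting from the paper, the `(batch, num_inputs, dim)` inputs typically come from flattening a CNN feature map into one token per spatial location (the paper also adds positional embeddings, omitted here for brevity). The encoder below is a hypothetical stand-in used only to show the plumbing, including the test-time override of the slot count:

```python
import torch
from torch import nn
from slot_attention import SlotAttention

# hypothetical CNN encoder producing a 512-channel feature map (assumption, not part of this package)
encoder = nn.Sequential(
    nn.Conv2d(3, 512, kernel_size = 5, padding = 2),
    nn.ReLU(),
    nn.Conv2d(512, 512, kernel_size = 5, padding = 2),
    nn.ReLU()
)

slot_attn = SlotAttention(num_slots = 5, dim = 512)

images = torch.randn(2, 3, 32, 32)
features = encoder(images)                    # (2, 512, 32, 32)
tokens = features.flatten(2).transpose(1, 2)  # (2, 1024, 512) - one token per spatial location

slots = slot_attn(tokens)                     # (2, 5, 512)
slots = slot_attn(tokens, num_slots = 8)      # (2, 8, 512) - more slots at test time
```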
```bibtex
@misc{locatello2020objectcentric,
    title         = {Object-Centric Learning with Slot Attention},
    author        = {Francesco Locatello and Dirk Weissenborn and Thomas Unterthiner and Aravindh Mahendran and Georg Heigold and Jakob Uszkoreit and Alexey Dosovitskiy and Thomas Kipf},
    year          = {2020},
    eprint        = {2006.15055},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}
```